feat: switch from pylint to ruff (#322)

* switch from pylint to ruff

* fix file names

* add codespell to ruff

* formatting with ruff

* configure codespell

* that was a lot

* i hope that was all of it

* yapf

* fix tests

* why did that break yapf

* let's see if the ignore works properly

* nope

* codespell

* seriously

* fix stablelm (why was it broken?)

* uhh?

AlpinDale, 1 year ago (commit e42a78381a)
85 changed files with 2327 additions and 2369 deletions
  1. .github/workflows/ruff.yml (+8, -5)
  2. .gitignore (+2, -1)
  3. .pylintrc (+0, -452)
  4. README.md (+1, -1)
  5. aphrodite/common/gguf.py (+13, -12)
  6. aphrodite/common/logger.py (+2, -1)
  7. aphrodite/common/outputs.py (+1, -0)
  8. aphrodite/common/sampling_params.py (+6, -6)
  9. aphrodite/common/sequence.py (+2, -1)
  10. aphrodite/common/utils.py (+3, -3)
  11. aphrodite/endpoints/llm.py (+5, -4)
  12. aphrodite/endpoints/openai/api_server.py (+14, -12)
  13. aphrodite/endpoints/openai/protocol.py (+2, -1)
  14. aphrodite/endpoints/openai/serving_chat.py (+9, -5)
  15. aphrodite/endpoints/openai/serving_completions.py (+16, -10)
  16. aphrodite/endpoints/openai/serving_engine.py (+6, -4)
  17. aphrodite/engine/aphrodite_engine.py (+5, -2)
  18. aphrodite/engine/metrics.py (+10, -5)
  19. aphrodite/kv_quant/observer.py (+2, -2)
  20. aphrodite/lora/layers.py (+6, -4)
  21. aphrodite/lora/models.py (+3, -2)
  22. aphrodite/lora/punica.py (+1, -1)
  23. aphrodite/lora/worker_manager.py (+3, -3)
  24. aphrodite/modeling/hf_downloader.py (+2, -1)
  25. aphrodite/modeling/layers/linear.py (+10, -5)
  26. aphrodite/modeling/layers/quantization/__init__.py (+1, -1)
  27. aphrodite/modeling/layers/quantization/aqlm.py (+6, -3)
  28. aphrodite/modeling/layers/quantization/awq.py (+4, -3)
  29. aphrodite/modeling/layers/quantization/bitsandbytes.py (+4, -3)
  30. aphrodite/modeling/layers/quantization/exl2.py (+2, -1)
  31. aphrodite/modeling/layers/quantization/marlin.py (+24, -14)
  32. aphrodite/modeling/layers/quantization/quip.py (+2, -1)
  33. aphrodite/modeling/layers/quantization/quip_utils.py (+2, -3)
  34. aphrodite/modeling/layers/quantization/squeezellm.py (+2, -1)
  35. aphrodite/modeling/layers/rotary_embedding.py (+2, -2)
  36. aphrodite/modeling/layers/sampler.py (+3, -3)
  37. aphrodite/modeling/layers/triton_kernel/fused_moe.py (+6, -5)
  38. aphrodite/modeling/layers/triton_kernel/prefix_prefill.py (+1, -1)
  39. aphrodite/modeling/megatron/custom_all_reduce.py (+2, -2)
  40. aphrodite/modeling/megatron/parallel_state.py (+1, -1)
  41. aphrodite/modeling/models/__init__.py (+1, -1)
  42. aphrodite/modeling/models/baichuan.py (+82, -54)
  43. aphrodite/modeling/models/deepseek.py (+78, -53)
  44. aphrodite/modeling/models/gemma.py (+86, -59)
  45. aphrodite/modeling/models/gpt_j.py (+53, -35)
  46. aphrodite/modeling/models/gpt_neox.py (+25, -15)
  47. aphrodite/modeling/models/internlm2.py (+60, -41)
  48. aphrodite/modeling/models/llama.py (+77, -52)
  49. aphrodite/modeling/models/mixtral.py (+129, -84)
  50. aphrodite/modeling/models/mixtral_quant.py (+94, -64)
  51. aphrodite/modeling/models/olmo.py (+45, -30)
  52. aphrodite/modeling/models/opt.py (+56, -36)
  53. aphrodite/modeling/models/phi.py (+80, -52)
  54. aphrodite/modeling/models/qwen.py (+68, -46)
  55. aphrodite/modeling/models/qwen2.py (+86, -59)
  56. aphrodite/modeling/models/stablelm.py (+105, -74)
  57. aphrodite/modeling/outlines_decoding.py (+9, -4)
  58. aphrodite/modeling/outlines_logits_processors.py (+14, -9)
  59. aphrodite/processing/block_manager.py (+45, -27)
  60. aphrodite/processing/evictor.py (+3, -3)
  61. aphrodite/processing/scheduler.py (+27, -15)
  62. aphrodite/task_handler/model_runner.py (+208, -135)
  63. aphrodite/task_handler/worker.py (+2, -1)
  64. aphrodite/transformers_utils/configs/mpt.py (+108, -151)
  65. aphrodite/transformers_utils/configs/olmo.py (+2, -1)
  66. aphrodite/transformers_utils/tokenizers/baichuan.py (+22, -11)
  67. env.py (+145, -102)
  68. examples/gguf_to_torch.py (+89, -39)
  69. examples/gradio_server.py (+19, -10)
  70. examples/marlin/convert.py (+42, -12)
  71. examples/offline_inference.py (+4, -3)
  72. examples/slora_inference.py (+44, -24)
  73. formatting.sh (+89, -9)
  74. pyproject.toml (+45, -0)
  75. requirements-dev.txt (+15, -2)
  76. setup.py (+56, -36)
  77. tests/async_engine/api_server_async_aphrodite.py (+0, -50)
  78. tests/async_engine/test_api_server.py (+0, -93)
  79. tests/async_engine/test_async_aphrodite.py (+0, -91)
  80. tests/async_engine/test_openai_server.py (+0, -119)
  81. tests/async_engine/test_request_tracker.py (+0, -65)
  82. tests/benchmarks/backend_request_func.py (+6, -3)
  83. tests/benchmarks/serving.py (+4, -3)
  84. tests/endpoints/test_openai_server.py (+9, -7)
  85. tests/engine/test_detokenize.py (+1, -2)

+ 8 - 5
.github/workflows/pylint.yml → .github/workflows/ruff.yml

@@ -1,4 +1,4 @@
-name: pylint
+name: ruff
 
 on:
   # Trigger the workflow on push or pull request,
@@ -13,7 +13,7 @@ on:
       - dev
 
 jobs:
-  pylint:
+  ruff:
     runs-on: ubuntu-latest
     strategy:
       matrix:
@@ -27,7 +27,10 @@ jobs:
     - name: Install dependencies
       run: |
         python -m pip install --upgrade pip
-        pip install pylint==2.8.2
-    - name: Analysing the code with pylint
+        pip install ruff==0.1.5 codespell==2.2.6 tomli==2.0.1
+    - name: Analysing the code with ruff
       run: |
-        pylint aphrodite tests
+        ruff aphrodite tests
+    - name: Spelling check with codespell
+      run: |
+         codespell --toml pyproject.toml
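
Note on reproducing the CI checks locally: the workflow above just installs the pinned tools and runs two commands, so the same checks can be run from a checkout before pushing. The snippet below only repeats the commands from the workflow (same version pins); it assumes a working Python environment and is a convenience sketch, not something added by this commit.

    pip install ruff==0.1.5 codespell==2.2.6 tomli==2.0.1
    # Lint the same paths the CI job lints
    ruff aphrodite tests
    # Spell-check using the codespell settings stored in pyproject.toml
    codespell --toml pyproject.toml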

+ 2 - 1
.gitignore

@@ -198,4 +198,5 @@ _build/
 *_hip*
 
 kv_cache_states/*
-quant_params/*
+quant_params/*
+.ruff_cache/

+ 0 - 452
.pylintrc

@@ -1,452 +0,0 @@
-# This Pylint rcfile contains a best-effort configuration to uphold the
-# best-practices and style described in the Google Python style guide:
-#   https://google.github.io/styleguide/pyguide.html
-#
-# Its canonical open-source location is:
-#   https://google.github.io/styleguide/pylintrc
-
-[MASTER]
-
-# Files or directories to be skipped. They should be base names, not paths.
-ignore=docs
-
-# Files or directories matching the regex patterns are skipped. The regex
-# matches against base names, not paths.
-ignore-patterns=
-
-# Pickle collected data for later comparisons.
-persistent=no
-
-# List of plugins (as comma separated values of python modules names) to load,
-# usually to register additional checkers.
-load-plugins=
-
-# Use multiple processes to speed up Pylint.
-jobs=4
-
-# Allow loading of arbitrary C extensions. Extensions are imported into the
-# active Python interpreter and may run arbitrary code.
-unsafe-load-any-extension=no
-
-
-[MESSAGES CONTROL]
-
-# Only show warnings with the listed confidence levels. Leave empty to show
-# all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED
-confidence=
-
-# Enable the message, report, category or checker with the given id(s). You can
-# either give multiple identifier separated by comma (,) or put this option
-# multiple time (only on the command line, not in the configuration file where
-# it should appear only once). See also the "--disable" option for examples.
-#enable=
-
-# Disable the message, report, category or checker with the given id(s). You
-# can either give multiple identifiers separated by comma (,) or put this
-# option multiple times (only on the command line, not in the configuration
-# file where it should appear only once).You can also use "--disable=all" to
-# disable everything first and then reenable specific checks. For example, if
-# you want to run only the similarities checker, you can use "--disable=all
-# --enable=similarities". If you want to run only the classes checker, but have
-# no Warning level messages displayed, use"--disable=all --enable=classes
-# --disable=W"
-disable=abstract-method,
-        consider-using-with,
-        consider-using-in,
-        invalid-overriden-method,
-        apply-builtin,
-        unnecessary-comprehension,
-        arguments-differ,
-        attribute-defined-outside-init,
-        backtick,
-        bad-option-value,
-        basestring-builtin,
-        buffer-builtin,
-        c-extension-no-member,
-        consider-using-enumerate,
-        cmp-builtin,
-        inconsistent-quotes,
-        cmp-method,
-        coerce-builtin,
-        coerce-method,
-        dangerous-default-value,
-        delslice-method,
-        div-method,
-        duplicate-code,
-        eq-without-hash,
-        execfile-builtin,
-        file-builtin,
-        filter-builtin-not-iterating,
-        fixme,
-        getslice-method,
-        global-statement,
-        hex-method,
-        idiv-method,
-        implicit-str-concat-in-sequence,
-        import-error,
-        import-self,
-        import-star-module-level,
-        import-outside-toplevel,
-        use-a-generator,
-        unused-argument,
-        inconsistent-return-statements,
-        consider-using-get,
-        input-builtin,
-        intern-builtin,
-        invalid-str-codec,
-        invalid-name,
-        locally-disabled,
-        broad-except,
-        redefined-outer-name,
-        logging-fstring-interpolation,
-        logging-not-lazy,
-        unsubscriptable-object,
-        long-builtin,
-        long-suffix,
-        line-too-long,
-        map-builtin-not-iterating,
-        misplaced-comparison-constant,
-        missing-class-docstring,
-        missing-function-docstring,
-        missing-module-docstring,
-        metaclass-assignment,
-        next-method-called,
-        next-method-defined,
-        no-absolute-import,
-        no-else-break,
-        no-else-continue,
-        no-else-raise,
-        no-else-return,
-        no-init,
-        no-member,
-        no-name-in-module,
-        no-self-use,
-        nonzero-method,
-        oct-method,
-        old-division,
-        old-ne-operator,
-        old-octal-literal,
-        old-raise-syntax,
-        parameter-unpacking,
-        print-statement,
-        protected-access,
-        raising-string,
-        range-builtin-not-iterating,
-        raw_input-builtin,
-        rdiv-method,
-        reduce-builtin,
-        relative-import,
-        reload-builtin,
-        round-builtin,
-        setslice-method,
-        signature-differs,
-        standarderror-builtin,
-        suppressed-message,
-        super-init-not-called,
-        sys-max-int,
-        too-few-public-methods,
-        too-many-ancestors,
-        too-many-arguments,
-        too-many-boolean-expressions,
-        too-many-branches,
-        too-many-instance-attributes,
-        too-many-locals,
-        too-many-nested-blocks,
-        too-many-public-methods,
-        too-many-return-statements,
-        too-many-statements,
-        trailing-newlines,
-        unichr-builtin,
-        unicode-builtin,
-        unnecessary-pass,
-        unpacking-in-except,
-        unspecified-encoding,
-        useless-else-on-loop,
-        useless-object-inheritance,
-        useless-suppression,
-        useless-return,
-        using-cmp-argument,
-        wrong-import-order,
-        xrange-builtin,
-        zip-builtin-not-iterating,
-
-
-[REPORTS]
-
-# Set the output format. Available formats are text, parseable, colorized, msvs
-# (visual studio) and html. You can also give a reporter class, eg
-# mypackage.mymodule.MyReporterClass.
-output-format=text
-
-# Tells whether to display a full report or only the messages
-reports=no
-
-# Python expression which should return a note less than 10 (10 is the highest
-# note). You have access to the variables errors warning, statement which
-# respectively contain the number of errors / warnings messages and the total
-# number of statements analyzed. This is used by the global evaluation report
-# (RP0004).
-evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10)
-
-# Template used to display messages. This is a python new-style format string
-# used to format the message information. See doc for all details
-#msg-template=
-
-
-[BASIC]
-
-# Good variable names which should always be accepted, separated by a comma
-good-names=main,_
-
-# Bad variable names which should always be refused, separated by a comma
-bad-names=
-
-# Colon-delimited sets of names that determine each other's naming style when
-# the name regexes allow several styles.
-name-group=
-
-# Include a hint for the correct naming format with invalid-name
-include-naming-hint=no
-
-# List of decorators that produce properties, such as abc.abstractproperty. Add
-# to this list to register other decorators that produce valid properties.
-property-classes=abc.abstractproperty,cached_property.cached_property,cached_property.threaded_cached_property,cached_property.cached_property_with_ttl,cached_property.threaded_cached_property_with_ttl
-
-# Regular expression matching correct function names
-function-rgx=^(?:(?P<exempt>setUp|tearDown|setUpModule|tearDownModule)|(?P<camel_case>_?[A-Z][a-zA-Z0-9]*)|(?P<snake_case>_?[a-z][a-z0-9_]*))$
-
-# Regular expression matching correct variable names
-variable-rgx=^[a-z][a-z0-9_]*$
-
-# Regular expression matching correct constant names
-const-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$
-
-# Regular expression matching correct attribute names
-attr-rgx=^_{0,2}[a-z][a-z0-9_]*$
-
-# Regular expression matching correct argument names
-argument-rgx=^[a-z][a-z0-9_]*$
-
-# Regular expression matching correct class attribute names
-class-attribute-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$
-
-# Regular expression matching correct inline iteration names
-inlinevar-rgx=^[a-z][a-z0-9_]*$
-
-# Regular expression matching correct class names
-class-rgx=^_?[A-Z][a-zA-Z0-9]*$
-
-# Regular expression matching correct module names
-module-rgx=^(_?[a-z][a-z0-9_]*|__init__)$
-
-# Regular expression matching correct method names
-method-rgx=(?x)^(?:(?P<exempt>_[a-z0-9_]+__|runTest|setUp|tearDown|setUpTestCase|tearDownTestCase|setupSelf|tearDownClass|setUpClass|(test|assert)_*[A-Z0-9][a-zA-Z0-9_]*|next)|(?P<camel_case>_{0,2}[A-Z][a-zA-Z0-9_]*)|(?P<snake_case>_{0,2}[a-z][a-z0-9_]*))$
-
-# Regular expression which should only match function or class names that do
-# not require a docstring.
-no-docstring-rgx=(__.*__|main|test.*|.*test|.*Test)$
-
-# Minimum line length for functions/classes that require docstrings, shorter
-# ones are exempt.
-docstring-min-length=10
-
-
-[TYPECHECK]
-
-# List of decorators that produce context managers, such as
-# contextlib.contextmanager. Add to this list to register other decorators that
-# produce valid context managers.
-contextmanager-decorators=contextlib.contextmanager,contextlib2.contextmanager
-
-# Tells whether missing members accessed in mixin class should be ignored. A
-# mixin class is detected if its name ends with "mixin" (case insensitive).
-ignore-mixin-members=yes
-
-# List of module names for which member attributes should not be checked
-# (useful for modules/projects where namespaces are manipulated during runtime
-# and thus existing member attributes cannot be deduced by static analysis. It
-# supports qualified module names, as well as Unix pattern matching.
-ignored-modules=
-
-# List of class names for which member attributes should not be checked (useful
-# for classes with dynamically set attributes). This supports the use of
-# qualified names.
-ignored-classes=optparse.Values,thread._local,_thread._local
-
-# List of members which are set dynamically and missed by pylint inference
-# system, and so shouldn't trigger E1101 when accessed. Python regular
-# expressions are accepted.
-generated-members=
-
-
-[FORMAT]
-
-# Maximum number of characters on a single line.
-max-line-length=80
-
-# TODO(https://github.com/PyCQA/pylint/issues/3352): Direct pylint to exempt
-# lines made too long by directives to pytype.
-
-# Regexp for a line that is allowed to be longer than the limit.
-ignore-long-lines=(?x)(
-  ^\s*(\#\ )?<?https?://\S+>?$|
-  ^\s*(from\s+\S+\s+)?import\s+.+$)
-
-# Allow the body of an if to be on the same line as the test if there is no
-# else.
-single-line-if-stmt=yes
-
-# Maximum number of lines in a module
-max-module-lines=99999
-
-# String used as indentation unit.  The internal Google style guide mandates 2
-# spaces.  Google's externaly-published style guide says 4, consistent with
-# PEP 8.  Here, we use 2 spaces, for conformity with many open-sourced Google
-# projects (like TensorFlow).
-indent-string='    '
-
-# Number of spaces of indent required inside a hanging  or continued line.
-indent-after-paren=4
-
-# Expected format of line ending, e.g. empty (any line ending), LF or CRLF.
-expected-line-ending-format=
-
-
-[MISCELLANEOUS]
-
-# List of note tags to take in consideration, separated by a comma.
-notes=TODO
-
-
-[STRING]
-
-# This flag controls whether inconsistent-quotes generates a warning when the
-# character used as a quote delimiter is used inconsistently within a module.
-check-quote-consistency=yes
-
-
-[VARIABLES]
-
-# Tells whether we should check for unused import in __init__ files.
-init-import=no
-
-# A regular expression matching the name of dummy variables (i.e. expectedly
-# not used).
-dummy-variables-rgx=^\*{0,2}(_$|unused_|dummy_)
-
-# List of additional names supposed to be defined in builtins. Remember that
-# you should avoid to define new builtins when possible.
-additional-builtins=
-
-# List of strings which can identify a callback function by name. A callback
-# name must start or end with one of those strings.
-callbacks=cb_,_cb
-
-# List of qualified module names which can have objects that can redefine
-# builtins.
-redefining-builtins-modules=six,six.moves,past.builtins,future.builtins,functools
-
-
-[LOGGING]
-
-# Logging modules to check that the string format arguments are in logging
-# function parameter format
-logging-modules=logging,absl.logging,tensorflow.io.logging
-
-
-[SIMILARITIES]
-
-# Minimum lines number of a similarity.
-min-similarity-lines=4
-
-# Ignore comments when computing similarities.
-ignore-comments=yes
-
-# Ignore docstrings when computing similarities.
-ignore-docstrings=yes
-
-# Ignore imports when computing similarities.
-ignore-imports=no
-
-
-[SPELLING]
-
-# Spelling dictionary name. Available dictionaries: none. To make it working
-# install python-enchant package.
-spelling-dict=
-
-# List of comma separated words that should not be checked.
-spelling-ignore-words=
-
-# A path to a file that contains private dictionary; one word per line.
-spelling-private-dict-file=
-
-# Tells whether to store unknown words to indicated private dictionary in
-# --spelling-private-dict-file option instead of raising a message.
-spelling-store-unknown-words=no
-
-
-[IMPORTS]
-
-# Deprecated modules which should not be used, separated by a comma
-deprecated-modules=regsub,
-                   TERMIOS,
-                   Bastion,
-                   rexec,
-                   sets
-
-# Create a graph of every (i.e. internal and external) dependencies in the
-# given file (report RP0402 must not be disabled)
-import-graph=
-
-# Create a graph of external dependencies in the given file (report RP0402 must
-# not be disabled)
-ext-import-graph=
-
-# Create a graph of internal dependencies in the given file (report RP0402 must
-# not be disabled)
-int-import-graph=
-
-# Force import order to recognize a module as part of the standard
-# compatibility libraries.
-known-standard-library=
-
-# Force import order to recognize a module as part of a third party library.
-known-third-party=enchant, absl
-
-# Analyse import fallback blocks. This can be used to support both Python 2 and
-# 3 compatible code, which means that the block might have code that exists
-# only in one or another interpreter, leading to false positives when analysed.
-analyse-fallback-blocks=no
-
-
-[CLASSES]
-
-# List of method names used to declare (i.e. assign) instance attributes.
-defining-attr-methods=__init__,
-                      __new__,
-                      setUp
-
-# List of member names, which should be excluded from the protected access
-# warning.
-exclude-protected=_asdict,
-                  _fields,
-                  _replace,
-                  _source,
-                  _make
-
-# List of valid names for the first argument in a class method.
-valid-classmethod-first-arg=cls,
-                            class_
-
-# List of valid names for the first argument in a metaclass class method.
-valid-metaclass-classmethod-first-arg=mcs
-
-
-[EXCEPTIONS]
-
-# Exceptions that will emit a warning when being caught. Defaults to
-# "Exception"
-overgeneral-exceptions=StandardError,
-                       Exception,
-                       BaseException
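
With the 452-line pylint rcfile deleted, rule configuration now lives in the pyproject.toml added by this commit (its contents are not shown in this section, but the codespell step above reads it via --toml, so it presumably carries [tool.ruff] and [tool.codespell] tables), and individual violations are silenced inline with "# ruff: noqa: ..." comments, as in several hunks below (E501 for long lines, E731 for lambda assignment). To look up what a given code enforces, ruff's own CLI has a rule subcommand; this is standard ruff tooling, not something introduced by this commit:

    # Print ruff's documentation for a single rule code
    ruff rule E731
    ruff rule E501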

+ 1 - 1
README.md

@@ -83,7 +83,7 @@ GPU: NVIDIA A40, Mistral 7B. Baseline is the same model loaded with text-generat
 ### High Batch Size Performance
 
 > [!NOTE]  
-> The numbers below are the theoritical peak achieved by *only* requesting output tokens at very high batch sizes. At lower batch sizes with much larger prompts, the results will be vastly different.
+> The numbers below are the theoretical peak achieved by *only* requesting output tokens at very high batch sizes. At lower batch sizes with much larger prompts, the results will be vastly different.
 Throughput refers to output tokens per second.
 
 This table is outdated, will be replaced soon.

+ 13 - 12
aphrodite/common/gguf.py

@@ -157,15 +157,16 @@ class GGUFReader:
         offs += 4
         temp_version = self._get(offs, np.uint32)
         if temp_version[0] & 65535 == 0:
-            # If we get 0 here that means it's (probably) a GGUF file created for
-            # the opposite byte order of the machine this script is running on.
+            # If we get 0 here that means it's (probably) a GGUF file created
+            # for the opposite byte order of the machine this script is
+            # running on.
             self.byte_order = 'S'
             temp_version = temp_version.newbyteorder(self.byte_order)
         version = temp_version[0]
         if version not in READER_SUPPORTED_VERSIONS:
             raise ValueError(
-                f'Sorry, file appears to be version {version} which we cannot handle'
-            )
+                f'Sorry, file appears to be version {version} which we cannot '
+                'handle')
         self.fields: OrderedDict[str, ReaderField] = OrderedDict()
         self.tensors: list[ReaderTensor] = []
         offs += self._push_field(
@@ -217,9 +218,8 @@ class GGUFReader:
 
     def _push_field(self, field: ReaderField, skip_sum: bool = False) -> int:
         if field.name in self.fields:
-            raise KeyError(
-                f'Duplicate {field.name} already in list at offset {field.offset}'
-            )
+            raise KeyError(f'Duplicate {field.name} already in list at offset '
+                           f'{field.offset}')
         self.fields[field.name] = field
         return 0 if skip_sum else sum(int(part.nbytes) for part in field.parts)
 
@@ -257,8 +257,8 @@ class GGUFReader:
             aparts: list[npt.NDArray[Any]] = [raw_itype, alen]
             data_idxs: list[int] = []
             for idx in range(alen[0]):
-                curr_size, curr_parts, curr_idxs, curr_types = self._get_field_parts(
-                    offs, raw_itype[0])
+                curr_size, curr_parts, curr_idxs, curr_types = (
+                    self._get_field_parts(offs, raw_itype[0]))
                 if idx == 0:
                     types += curr_types
                 idxs_offs = len(aparts)
@@ -297,8 +297,8 @@ class GGUFReader:
             offs += int(raw_kv_type.nbytes)
             parts: list[npt.NDArray[Any]] = [kv_klen, kv_kdata, raw_kv_type]
             idxs_offs = len(parts)
-            field_size, field_parts, field_idxs, field_types = self._get_field_parts(
-                offs, raw_kv_type[0])
+            field_size, field_parts, field_idxs, field_types = (
+                self._get_field_parts(offs, raw_kv_type[0]))
             parts += field_parts
             self._push_field(ReaderField(
                 orig_offs,
@@ -325,7 +325,8 @@ class GGUFReader:
         tensors = []
         for field in fields:
             # pylint: disable=unused-variable
-            _name_len, name_data, _n_dims, dims, raw_dtype, offset_tensor = field.parts
+            (_name_len, name_data, _n_dims, dims, raw_dtype,
+             offset_tensor) = field.parts
             ggml_type = GGMLQuantizationType(raw_dtype[0])
             n_elems = np.prod(dims)
             block_size, type_size = GGML_QUANT_SIZES[ggml_type]

+ 2 - 1
aphrodite/common/logger.py

@@ -85,7 +85,8 @@ class UvicornLoggingHandler(logging.Handler):
                                                   self.format(record).rstrip())
 
 
-# Uvicorn config for logging. Passed into run when creating all loggers in server
+# Uvicorn config for logging. Passed into run when creating all loggers in
+#server
 UVICORN_LOG_CONFIG = {
     "version": 1,
     "disable_existing_loggers": False,

+ 1 - 0
aphrodite/common/outputs.py

@@ -90,6 +90,7 @@ class RequestOutput:
             sorting_key = lambda seq: seq.get_beam_search_score(
                 seq_group.sampling_params.length_penalty)
         else:
+            # ruff: noqa: E731
             sorting_key = lambda seq: seq.get_cumulative_logprob()
         sorted_seqs = sorted(seqs, key=sorting_key, reverse=True)
         top_n_seqs = sorted_seqs[:n]

+ 6 - 6
aphrodite/common/sampling_params.py

@@ -59,15 +59,15 @@ class SamplingParams:
             Exact cutoff is top_a*max_prob**2. Must be in [0,inf], 0 to disable.
         min_p: Float that controls the cutoff for min-p sampling.
             Exact cutoff is min_p*max_prob. Must be in [0,1], 0 to disable.
-        tfs: Float that controls the cummulative approximate curvature of the
+        tfs: Float that controls the cumulative approximate curvature of the
             distribution to retain for Tail Free Sampling.
             Must be in (0, 1]. Set to 1 to disable
-        eta_cutoff: Float that controls the cutoff treshold for Eta sampling
+        eta_cutoff: Float that controls the cutoff threshold for Eta sampling
             (a form of entropy adaptive truncation sampling)
-            treshold is computed as min(eta, sqrt(eta)*entropy(probs)).
+            threshold is computed as min(eta, sqrt(eta)*entropy(probs)).
             Specified in units of 1e-4. Set to 0 to disable
-        epsilon_cutoff: Float that controls the cutoff treshold for
-            Epsilon sampling (simple probability treshold truncation).
+        epsilon_cutoff: Float that controls the cutoff threshold for
+            Epsilon sampling (simple probability threshold truncation).
             Specified in units of 1e-4. Set to 0 to disable.
         typical_p: Float that controls the cumulative probability of tokens
             closest in surprise to the expected surprise to consider.
@@ -99,7 +99,7 @@ class SamplingParams:
             The returned output will not contain the stop strings.
         stop_token_ids: List of tokens that stop the generation when they are
             generated. The returned output will contain the stop tokens unless
-            the stop tokens are sepcial tokens.
+            the stop tokens are special tokens.
         include_stop_str_in_output: Whether to include the stop strings in
             output text. Defaults to False.
         ignore_eos: Whether to ignore the EOS token and continue generating

+ 2 - 1
aphrodite/common/sequence.py

@@ -335,7 +335,8 @@ class SequenceGroup:
             self.metrics.first_token_time = time
 
     def maybe_set_first_scheduled_time(self, time: float) -> None:
-        """Sets the first scheduled time and time in queue for Request level timings."""
+        """Sets the first scheduled time and time in queue for Request level
+        timings."""
         if self.metrics.first_scheduled_time is None:
             self.metrics.first_scheduled_time = time
             self.metrics.time_in_queue = time - self.metrics.arrival_time

+ 3 - 3
aphrodite/common/utils.py

@@ -174,8 +174,8 @@ def get_nvcc_cuda_version() -> Optional[Version]:
         cuda_home = '/usr/local/cuda'
         if os.path.isfile(cuda_home + '/bin/nvcc'):
             logger.info(
-                f'CUDA_HOME is not found in the environment. Using {cuda_home} as CUDA_HOME.'
-            )
+                f'CUDA_HOME is not found in the environment. Using {cuda_home} '
+                'as CUDA_HOME.')
         else:
             logger.warning(
                 f'Not found nvcc in {cuda_home}. Skipping cuda version check!')
@@ -196,7 +196,7 @@ def _generate_random_fp8_e5m2(
     # NOTE: Due to NaN and Inf representation for fp8 data type,
     # we may get Inf or NaN if we directly use torch.randint
     # to generate random data for fp8 data.
-    # For example, s.11111.00 in fp8e5m2 format repesents Inf.
+    # For example, s.11111.00 in fp8e5m2 format represents Inf.
     #     | E4M3        | E5M2
     #-----|-------------|-------------------
     # Inf | N/A         | s.11111.00

+ 5 - 4
aphrodite/endpoints/llm.py

@@ -148,10 +148,11 @@ class LLM:
         if isinstance(prompts, str):
             # Convert a single prompt to a list.
             prompts = [prompts]
-        if prompts is not None and prompt_token_ids is not None:
-            if len(prompts) != len(prompt_token_ids):
-                raise ValueError("The lengths of prompts and prompt_token_ids "
-                                 "must be the same.")
+        if prompts is not None and prompt_token_ids is not None and len(
+                prompts) != len(prompt_token_ids):
+            raise ValueError(
+                "The lengths of prompts and prompt_token_ids must "
+                "be the same.")
         if sampling_params is None:
             # Use default sampling params.
             sampling_params = SamplingParams()

+ 14 - 12
aphrodite/endpoints/openai/api_server.py

@@ -14,7 +14,8 @@ from http import HTTPStatus
 from fastapi import Request, APIRouter, Header
 from fastapi.exceptions import RequestValidationError
 from fastapi.middleware.cors import CORSMiddleware
-from fastapi.responses import JSONResponse, StreamingResponse, Response, HTMLResponse
+from fastapi.responses import (JSONResponse, StreamingResponse, Response,
+                               HTMLResponse)
 from loguru import logger
 
 import aphrodite
@@ -28,7 +29,8 @@ from aphrodite.common.outputs import RequestOutput
 from aphrodite.common.sampling_params import SamplingParams, _SAMPLING_EPS
 from aphrodite.common.utils import random_uuid
 from aphrodite.endpoints.openai.serving_chat import OpenAIServingChat
-from aphrodite.endpoints.openai.serving_completions import OpenAIServingCompletion
+from aphrodite.endpoints.openai.serving_completions import (
+    OpenAIServingCompletion)
 from aphrodite.endpoints.openai.protocol import KAIGenerationInputSchema
 from aphrodite.endpoints.openai.serving_engine import LoRA
 from aphrodite.transformers_utils.tokenizer import get_tokenizer
@@ -100,8 +102,8 @@ def parse_args():
         type=str,
         default=None,
         help=
-        "If provided, the server will require this key to be presented in the header."
-    )
+        "If provided, the server will require this key to be presented in the "
+        "header.")
     parser.add_argument(
         "--launch-kobold-api",
         action="store_true",
@@ -125,8 +127,8 @@ def parse_args():
         nargs='+',
         action=LoRAParserAction,
         help=
-        "LoRA module configurations in the format name=path. Multiple modules can be specified."
-    )
+        "LoRA module configurations in the format name=path. Multiple modules "
+        "can be specified.")
     parser.add_argument("--chat-template",
                         type=str,
                         default=None,
@@ -159,9 +161,10 @@ def parse_args():
         help="Additional ASGI middleware to apply to the app. "
         "We accept multiple --middleware arguments. "
         "The value should be an import path. "
-        "If a function is provided, Aphrodite will add it to the server using @app.middleware('http'). "
-        "If a class is provided, Aphrodite will add it to the server using app.add_middleware(). "
-    )
+        "If a function is provided, Aphrodite will add it to the server using "
+        "@app.middleware('http'). "
+        "If a class is provided, Aphrodite will add it to the server using "
+        "app.add_middleware(). ")
 
     parser = AsyncEngineArgs.add_cli_args(parser)
     return parser.parse_args()
@@ -546,9 +549,8 @@ if __name__ == "__main__":
         elif inspect.iscoroutinefunction(imported):
             app.middleware("http")(imported)
         else:
-            raise ValueError(
-                f"Invalid middleware {middleware}. Must be a function or a class."
-            )
+            raise ValueError(f"Invalid middleware {middleware}. Must be a "
+                             "function or a class.")
 
     logger.debug(f"args: {args}")
 

+ 2 - 1
aphrodite/endpoints/openai/protocol.py

@@ -3,7 +3,8 @@
 import time
 from typing import Dict, List, Literal, Optional, Union
 
-from pydantic import AliasChoices, BaseModel, Field, model_validator, root_validator
+from pydantic import (AliasChoices, BaseModel, Field, model_validator,
+                      root_validator)
 import torch
 
 from aphrodite.common.utils import random_uuid

+ 9 - 5
aphrodite/endpoints/openai/serving_chat.py

@@ -13,7 +13,8 @@ from aphrodite.endpoints.openai.protocol import (
     UsageInfo)
 from aphrodite.common.outputs import RequestOutput
 from aphrodite.endpoints.openai.serving_engine import OpenAIServing, LoRA
-from aphrodite.modeling.outlines_decoding import get_guided_decoding_logits_processor
+from aphrodite.modeling.outlines_decoding import (
+    get_guided_decoding_logits_processor)
 
 
 class OpenAIServingChat(OpenAIServing):
@@ -37,7 +38,8 @@ class OpenAIServingChat(OpenAIServing):
         """Completion API similar to OpenAI's API.
 
         See  https://platform.openai.com/docs/api-reference/chat/create
-        for the API specification. This API mimics the OpenAI ChatCompletion API.
+        for the API specification. This API mimics the OpenAI ChatCompletion
+        API.
 
         NOTE: Currently we do not support the following feature:
             - function_call (Users should implement this by themselves)
@@ -115,7 +117,8 @@ class OpenAIServingChat(OpenAIServing):
                 # the result_generator, it needs to be sent as the FIRST
                 # response (by the try...catch).
                 if first_iteration:
-                    # Send first response for each request.n (index) with the role
+                    # Send first response for each request.n (index) with
+                    # the role
                     role = self.get_chat_request_role(request)
                     for i in range(request.n):
                         choice_data = ChatCompletionResponseStreamChoice(
@@ -132,7 +135,8 @@ class OpenAIServingChat(OpenAIServing):
                         data = chunk.model_dump_json(exclude_unset=True)
                         yield f"data: {data}\n\n"
 
-                    # Send response to echo the input portion of the last message
+                    # Send response to echo the input portion of the last
+                    # message
                     if request.echo:
                         last_msg_content = ""
                         if request.messages and isinstance(
@@ -144,7 +148,7 @@ class OpenAIServingChat(OpenAIServing):
 
                         if last_msg_content:
                             for i in range(request.n):
-                                choice_data = ChatCompletionResponseStreamChoice(
+                                choice_data = ChatCompletionResponseStreamChoice(  # noqa
                                     index=i,
                                     delta=DeltaMessage(
                                         content=last_msg_content),

+ 16 - 10
aphrodite/endpoints/openai/serving_completions.py

@@ -1,7 +1,8 @@
 import asyncio
 import time
 from fastapi import Request
-from typing import AsyncGenerator, AsyncIterator, Callable, List, Optional, Dict, Tuple
+from typing import (AsyncGenerator, AsyncIterator, Callable, List, Optional,
+                    Dict, Tuple)
 
 from aphrodite.common.utils import random_uuid
 from aphrodite.engine.async_aphrodite import AsyncAphrodite
@@ -16,7 +17,8 @@ from aphrodite.endpoints.openai.protocol import (
 )
 from aphrodite.common.outputs import RequestOutput
 from aphrodite.endpoints.openai.serving_engine import OpenAIServing, LoRA
-from aphrodite.modeling.outlines_decoding import get_guided_decoding_logits_processor
+from aphrodite.modeling.outlines_decoding import (
+    get_guided_decoding_logits_processor)
 
 TypeTokenIDs = List[int]
 TypeTopLogProbs = List[Optional[Dict[int, float]]]
@@ -43,8 +45,8 @@ def parse_prompt_format(prompt) -> Tuple[bool, list]:
             prompts = prompt  # case 4: array of token arrays
         else:
             raise ValueError(
-                "prompt must be a string, array of strings, array of tokens, or array of token arrays"
-            )
+                "prompt must be a string, array of strings, array of tokens, "
+                "or array of token arrays")
     return prompt_is_tokens, prompts
 
 
@@ -153,7 +155,8 @@ class OpenAIServingCompletion(OpenAIServing):
             int, RequestOutput]] = merge_async_iterators(*generators)
 
         # Similar to the OpenAI API, when n != best_of, we do not stream the
-        # results. In addition, we do not stream the results when use beam search.
+        # results. In addition, we do not stream the results when use beam
+        # search.
         stream = (request.stream
                   and (request.best_of is None or request.n == request.best_of)
                   and not request.use_beam_search)
@@ -220,7 +223,8 @@ class OpenAIServingCompletion(OpenAIServing):
 
                 for output in res.outputs:
                     i = output.index + prompt_idx * request.n
-                    # TODO: optimize the performance by avoiding full text O(n^2) sending.
+                    # TODO: optimize the performance by avoiding full text
+                    # O(n^2) sending.
 
                     if request.echo and request.max_tokens == 0:
                         # only return the prompt
@@ -228,11 +232,12 @@ class OpenAIServingCompletion(OpenAIServing):
                         delta_token_ids = res.prompt_token_ids
                         top_logprobs = res.prompt_logprobs
                         has_echoed[i] = True
-                    elif request.echo and request.max_tokens > 0 and not has_echoed[
-                            i]:
+                    elif (request.echo and request.max_tokens > 0
+                          and not has_echoed[i]):
                         # echo the prompt and first token
                         delta_text = res.prompt + output.text
-                        delta_token_ids = res.prompt_token_ids + output.token_ids
+                        delta_token_ids = (res.prompt_token_ids +
+                                           output.token_ids)
                         top_logprobs = res.prompt_logprobs + (output.logprobs
                                                               or [])
                         has_echoed[i] = True
@@ -245,7 +250,8 @@ class OpenAIServingCompletion(OpenAIServing):
                             i]:] if output.logprobs else None
 
                     if request.logprobs is not None:
-                        assert top_logprobs is not None, "top_logprobs must be provided when logprobs is requested"
+                        assert top_logprobs is not None, "top_logprobs must " \
+                            "be provided when logprobs is requested"
                         logprobs = self._create_logprobs(
                             token_ids=delta_token_ids,
                             top_logprobs=top_logprobs,

+ 6 - 4
aphrodite/endpoints/openai/serving_engine.py

@@ -49,7 +49,8 @@ class OpenAIServing:
             event_loop = None
 
         if event_loop is not None and event_loop.is_running(
-        ):  # If the current is instanced by Ray Serve, there is already a running event loop
+        ):  # If the current is instanced by Ray Serve, there is already
+            # a running event loop
             event_loop.create_task(self._post_init())
         else:  # When using single Aphrodite without engine_use_ray
             asyncio.run(self._post_init())
@@ -188,9 +189,10 @@ class OpenAIServing:
 
         if token_num + request.max_tokens > self.max_model_len:
             raise ValueError(
-                f"This model's maximum context length is {self.max_model_len} tokens. "
-                f"However, you requested {request.max_tokens + token_num} tokens "
-                f"({token_num} in the messages, "
+                f"This model's maximum context length is {self.max_model_len} "
+                "tokens. "
+                f"However, you requested {request.max_tokens + token_num} "
+                f"tokens ({token_num} in the messages, "
                 f"{request.max_tokens} in the completion). "
                 f"Please reduce the length of the messages or completion.", )
         else:

+ 5 - 2
aphrodite/engine/aphrodite_engine.py

@@ -13,7 +13,8 @@ from aphrodite.common.config import (CacheConfig, ModelConfig, ParallelConfig,
 from aphrodite.processing.scheduler import Scheduler, SchedulerOutputs
 from aphrodite.engine.args_tools import EngineArgs
 from aphrodite.engine.metrics import StatLogger, Stats
-from aphrodite.engine.ray_tools import RayWorkerAphrodite, initialize_cluster, ray
+from aphrodite.engine.ray_tools import (RayWorkerAphrodite, initialize_cluster,
+                                        ray)
 from aphrodite.common.logger import setup_logger
 from aphrodite.common.outputs import RequestOutput
 from aphrodite.common.sampling_params import SamplingParams
@@ -75,7 +76,8 @@ class AphroditeEngine:
         log_stats: bool,
     ) -> None:
         logger.info(
-            f"Initializing the Aphrodite Engine (v{aphrodite.__version__}) with the following config:\n"
+            f"Initializing the Aphrodite Engine (v{aphrodite.__version__}) "
+            "with the following config:\n"
             f"Model = {model_config.model!r}\n"
             f"DataType = {model_config.dtype}\n"
             f"Model Load Format = {model_config.load_format}\n"
@@ -293,6 +295,7 @@ class AphroditeEngine:
                 self.scheduler_config)
 
     def _init_cache(self) -> None:
+        # ruff: noqa: E501
         """Profiles the memory usage and initializes the KV cache.
 
         The engine will first conduct a profiling of the existing memory usage.

+ 10 - 5
aphrodite/engine/metrics.py

@@ -161,9 +161,12 @@ class StatLogger:
     def _log_prometheus_interval(self, prompt_throughput: float,
                                  generation_throughput: float) -> None:
         # Logs metrics to prometheus that are computed every logging_interval.
-        # Support legacy gauge metrics that make throughput calculations on the Aphrodite side.
-        # Moving forward, we should use counters like counter_prompt_tokens, counter_generation_tokens
-        # Which log raw data and calculate summaries using rate() on the grafana/prometheus side.
+        # Support legacy gauge metrics that make throughput calculations on
+        # the Aphrodite side.
+        # Moving forward, we should use counters like counter_prompt_tokens,
+        # counter_generation_tokens
+        # Which log raw data and calculate summaries using rate() on the
+        # grafana/prometheus side.
         self.metrics.gauge_avg_prompt_throughput.labels(
             **self.labels).set(prompt_throughput)
         self.metrics.gauge_avg_generation_throughput.labels(
@@ -184,7 +187,8 @@ class StatLogger:
         # Log locally every local_interval seconds.
         if self._local_interval_elapsed(stats.now):
 
-            # Compute summary metrics for tracked stats (and log them to promethus if applicable).
+            # Compute summary metrics for tracked stats (and log them to
+            # prometheus if applicable).
             prompt_throughput = self._get_throughput(self.num_prompt_tokens,
                                                      now=stats.now)
             generation_throughput = self._get_throughput(
@@ -196,7 +200,8 @@ class StatLogger:
             # Log to stdout.
             logger.info(
                 f"Avg prompt throughput: {prompt_throughput:.1f} tokens/s, "
-                f"Avg generation throughput: {generation_throughput:.1f} tokens/s, "
+                f"Avg generation throughput: {generation_throughput:.1f} "
+                "tokens/s, "
                 f"Running: {stats.num_running} reqs, "
                 f"Swapped: {stats.num_swapped} reqs, "
                 f"Pending: {stats.num_waiting} reqs, "

+ 2 - 2
aphrodite/kv_quant/observer.py

@@ -114,8 +114,8 @@ class KVCacheObserver(GlobalAvailMixin):
             x = x.transpose(1, 2)
         elif x.size(2) != self.num_head or x.size(3) != self.head_dim:
             raise RuntimeError(
-                'Unexpected dimensions for x, expected (bs, num_head, seqlen, head_dim) or (bs, seqlen, num_head, head_dim)'
-            )
+                'Unexpected dimensions for x, expected (bs, num_head, '
+                'seqlen, head_dim) or (bs, seqlen, num_head, head_dim)')
 
         cur_max = x.flatten(0, 1).max(0)[0].cpu()
         cur_min = x.flatten(0, 1).min(0)[0].cpu()

+ 6 - 4
aphrodite/lora/layers.py

@@ -20,7 +20,8 @@ from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
                                               RowParallelLinear,
                                               QKVParallelLinear,
                                               MergedColumnParallelLinear)
-from aphrodite.modeling.layers.vocab_parallel_embedding import VocabParallelEmbedding, ParallelLMHead
+from aphrodite.modeling.layers.vocab_parallel_embedding import (
+    VocabParallelEmbedding, ParallelLMHead)
 from aphrodite.modeling.megatron.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
 from aphrodite.modeling.megatron.utils import split_tensor_along_last_dim
@@ -107,7 +108,8 @@ def _apply_lora_packed_nslice(
         lora_b_stacked:    3 element tuple of (num_loras, output_dim, lora_rank)
         indices:           (batch_size)
         output:            (batch_size, q_slice_size + 2*kv_slice_size)
-        output_slices:     n-1 element tuple of (slice_size...), where n is number of slices
+        output_slices:     n-1 element tuple of (slice_size...), where n is
+                           number of slices
     """
     org_output = output
     x = x.view(-1, x.shape[-1])
@@ -843,8 +845,8 @@ class SamplerWithLoRA(BaseLayerWithLoRA):
         # Keep this in sync with csrc/punica/bgmv/bgmv_config.h
         if 32000 < self.base_layer.vocab_size > 33024:
             raise ValueError(
-                "When using LoRA, vocab size must be 32000 >= vocab_size <= 33024"
-            )
+                "When using LoRA, vocab size must be 32000 >= vocab_size "
+                "<= 33024")
         self.lora_a_stacked = torch.zeros(
             (
                 max_loras,

+ 3 - 2
aphrodite/lora/models.py

@@ -13,7 +13,8 @@ from torch import nn
 from aphrodite.common.config import LoRAConfig
 from aphrodite.common.utils import LRUCache, in_wsl
 
-from aphrodite.lora.layers import BaseLayerWithLoRA, LoRAMapping, from_layer, from_layer_sampler
+from aphrodite.lora.layers import (BaseLayerWithLoRA, LoRAMapping, from_layer,
+                                   from_layer_sampler)
 from aphrodite.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
 from aphrodite.lora.utils import parse_fine_tuned_lora_name, replace_submodule
 
@@ -283,7 +284,7 @@ class LoRAModelManager:
                                               dtype=torch.long,
                                               device="cuda")
         self.offsets = []
-        # 4 is the number of indicies tensors defined above
+        # 4 is the number of indices tensors defined above
         # base_indices, sampler_indices, sampler_indices_padded,
         # embeddings_indices
         self.indices_len = [None] * 4

+ 1 - 1
aphrodite/lora/punica.py

@@ -85,7 +85,7 @@ def add_lora(y: torch.Tensor,
     r = wb_t_all.size(-1)
     if buffer is None:
         # We set the buffer to be float32 by default to avoid
-        # numerical innacuracies that would otherwise happen
+        # numerical inaccuracies that would otherwise happen
         # due to downcasting.
         buffer = torch.zeros((x.size(0), r),
                              dtype=torch.float32,

+ 3 - 3
aphrodite/lora/worker_manager.py

@@ -158,9 +158,9 @@ class WorkerLoRAManager(AbstractWorkerLoRAManager):
                 f"{self.lora_config.max_lora_rank}.")
         if lora.extra_vocab_size > self.lora_config.lora_extra_vocab_size:
             raise ValueError(
-                f"LoRA added vocab size {lora.extra_vocab_size} is greater than "
-                f"lora_extra_vocab_size {self.lora_config.lora_extra_vocab_size}."
-            )
+                f"LoRA added vocab size {lora.extra_vocab_size} is "
+                "greater than lora_extra_vocab_size "
+                f"{self.lora_config.lora_extra_vocab_size}.")
         return lora
 
     def add_dummy_lora(self, lora_request: LoRARequest, rank: int) -> bool:

+ 2 - 1
aphrodite/modeling/hf_downloader.py

@@ -212,7 +212,8 @@ def convert_gguf_to_state_dict(checkpoint, config):
 
     result = GGUFReader(checkpoint)
     # write tensor
-    kv_dim = config.hidden_size // config.num_attention_heads * config.num_key_value_heads
+    kv_dim = (config.hidden_size // config.num_attention_heads *
+              config.num_key_value_heads)
     tensor_mapping = {
         "token_embd": ("model.embed_tokens", config.vocab_size),
         "output": ("lm_head", config.vocab_size),

+ 10 - 5
aphrodite/modeling/layers/linear.py

@@ -154,7 +154,8 @@ class ColumnParallelLinear(torch.nn.Module):
                        skip adding bias but instead return it.
         params_dtype: Data type for the parameters.
         linear_method: (Maybe quantized) linear method.
-        output_sizes: list of output sizes packed into one output, like for QKV the list would be size 3.
+        output_sizes: list of output sizes packed into one output, like for
+                      QKV the list would be size 3.
     """
 
     def __init__(
@@ -304,7 +305,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
                     shard_size = shard_size // param.pack_factor
                     shard_offset = shard_offset // param.pack_factor
 
-                    # If marlin, we need to adjust the offset and size to account for the tiling.
+                    # If marlin, we need to adjust the offset and size to
+                    # account for the tiling.
                     shard_size, shard_offset = adjust_marlin_shard(
                         param, shard_size, shard_offset)
 
@@ -326,7 +328,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
                 shard_size = shard_size // param.pack_factor
                 shard_offset = shard_offset // param.pack_factor
 
-                # If marlin, we need to adjust the offset and size to account for the tiling.
+                # If marlin, we need to adjust the offset and size to account
+                # for the tiling.
                 shard_size, shard_offset = adjust_marlin_shard(
                     param, shard_size, shard_offset)
 
@@ -443,7 +446,8 @@ class QKVParallelLinear(ColumnParallelLinear):
                     shard_size = shard_size // param.pack_factor
                     shard_offset = shard_offset // param.pack_factor
 
-                    # If marlin, we need to adjust the offset and size to account for the tiling.
+                    # If marlin, we need to adjust the offset and size to
+                    # account for the tiling.
                     shard_size, shard_offset = adjust_marlin_shard(
                         param, shard_size, shard_offset)
 
@@ -472,7 +476,8 @@ class QKVParallelLinear(ColumnParallelLinear):
                 shard_size = shard_size // param.pack_factor
                 shard_offset = shard_offset // param.pack_factor
 
-                # If marlin, we need to adjust the offset and size to account for the tiling.
+                # If marlin, we need to adjust the offset and size to account
+                # for the tiling.
                 shard_size, shard_offset = adjust_marlin_shard(
                     param, shard_size, shard_offset)
 

+ 1 - 1
aphrodite/modeling/layers/quantization/__init__.py

@@ -1,5 +1,5 @@
 from typing import Type
-
+# ruff: noqa: E501
 from aphrodite.modeling.layers.quantization.base_config import QuantizationConfig
 from aphrodite.modeling.layers.quantization.aqlm import AQLMConfig
 from aphrodite.modeling.layers.quantization.awq import AWQConfig

+ 6 - 3
aphrodite/modeling/layers/quantization/aqlm.py

@@ -9,7 +9,8 @@ from torch.nn.parameter import Parameter
 
 from aphrodite._C import ops
 from aphrodite.modeling.layers.linear import LinearMethodBase, set_weight_attrs
-from aphrodite.modeling.layers.quantization.base_config import QuantizationConfig
+from aphrodite.modeling.layers.quantization.base_config import (
+    QuantizationConfig)
 
 
 def get_int_dtype(nbits: int) -> torch.dtype:
@@ -126,9 +127,11 @@ class AQLMLinearMethod(LinearMethodBase):
 
         codes = Parameter(
             torch.empty(
-                # There could actually be two pack factors, one along input and one along output,
+                # There could actually be two pack factors, one along input
+                # and one along output,
                 # but we don't currently support out_group_size,
-                # and only the one along output needs to be marked with "packed_dim".
+                # and only the one along output needs to be marked with
+                # "packed_dim".
                 # in order for QKVLinear to work.
                 output_size_per_partition,
                 input_size_per_partition // self.quant_config.pack_factor,

+ 4 - 3
aphrodite/modeling/layers/quantization/awq.py

@@ -6,7 +6,8 @@ from torch.nn.parameter import Parameter
 from aphrodite._C import ops
 from aphrodite.modeling.layers.linear import (LinearMethodBase,
                                               set_weight_attrs)
-from aphrodite.modeling.layers.quantization.base_config import QuantizationConfig
+from aphrodite.modeling.layers.quantization.base_config import (
+    QuantizationConfig)
 
 
 class AWQConfig(QuantizationConfig):
@@ -49,8 +50,8 @@ class AWQConfig(QuantizationConfig):
     @staticmethod
     def get_config_filenames() -> List[str]:
         return [
-            "quant_config.json",  # E.g., casperhansen/vicuna-7b-v1.5-awq
-            "quantize_config.json",  # E.g., abhinavkulkarni/mosaicml-mpt-7b-instruct-w4-g128-awq  # pylint: disable=line-too-long
+            "quant_config.json",
+            "quantize_config.json",
         ]
 
     @classmethod

+ 4 - 3
aphrodite/modeling/layers/quantization/bitsandbytes.py

@@ -5,7 +5,8 @@ from typing import List, Dict, Any, Optional, TypeVar, NamedTuple
 from aphrodite._C import ops
 from aphrodite.modeling.layers.linear import (LinearMethodBase,
                                               set_weight_attrs)
-from aphrodite.modeling.layers.quantization.base_config import QuantizationConfig
+from aphrodite.modeling.layers.quantization.base_config import (
+    QuantizationConfig)
 from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
@@ -67,8 +68,8 @@ class BitsandBytesConfig(QuantizationConfig):
     @staticmethod
     def get_config_filenames() -> List[str]:
         return [
-            "quant_config.json",  # E.g., casperhansen/vicuna-7b-v1.5-awq
-            "quantize_config.json",  # E.g., abhinavkulkarni/mosaicml-mpt-7b-instruct-w4-g128-awq
+            "quant_config.json",
+            "quantize_config.json",
         ]
 
     @classmethod

+ 2 - 1
aphrodite/modeling/layers/quantization/exl2.py

@@ -87,7 +87,8 @@ class Exl2LinearMethod(LinearMethodBase):
                        output_size: int,
                        params_dtype: torch.dtype) -> Dict[str, Any]:
         output_size_per_partition = sum(output_partition_sizes)
-        if input_size != input_size_per_partition or output_size != output_size_per_partition:
+        if (input_size != input_size_per_partition
+                or output_size != output_size_per_partition):
             raise ValueError(
                 "Currently exl2 doesn't support tensor parallel yet")
         # The shape of weight is unknown until load state dict
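
The wrapped condition above is the guard that rejects tensor parallelism for exl2: a sharded linear layer's per-partition size is the full size divided by the TP degree, so the sizes only match when nothing is sharded. A toy illustration of when the guard trips (values invented; a real column-/row-parallel layer shards only one of the two dimensions, which is why the guard uses "or"):

    def exl2_shapes_ok(input_size: int, output_size: int, tp_size: int) -> bool:
        # Per-partition sizes as a parallel linear layer would compute them.
        input_size_per_partition = input_size // tp_size
        output_size_per_partition = output_size // tp_size
        # Mirrors the guard: any sharding makes the partition sizes diverge.
        return (input_size == input_size_per_partition
                and output_size == output_size_per_partition)

    print(exl2_shapes_ok(4096, 4096, tp_size=1))  # True  -> weights created
    print(exl2_shapes_ok(4096, 4096, tp_size=2))  # False -> ValueError path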

+ 24 - 14
aphrodite/modeling/layers/quantization/marlin.py

@@ -5,7 +5,8 @@ from torch.nn.parameter import Parameter
 
 from aphrodite._C import ops
 from aphrodite.modeling.layers.linear import LinearMethodBase, set_weight_attrs
-from aphrodite.modeling.layers.quantization.base_config import QuantizationConfig
+from aphrodite.modeling.layers.quantization.base_config import (
+    QuantizationConfig)
 
 
 class MarlinConfig(QuantizationConfig):
@@ -22,8 +23,9 @@ class MarlinConfig(QuantizationConfig):
         self.group_size = group_size
         if self.group_size != 128 and self.group_size != -1:
             raise ValueError(
-                "Currently, only group size 128 and -1 (channelwise) is supported for "
-                f"Marlin, but got group_size of {self.group_size}")
+                "Currently, only group size 128 and -1 (channelwise) is "
+                f"supported for Marlin, but got group_size of {self.group_size}"
+            )
 
         # 4 Bits packed into 32 bit datatype.
         self.pack_factor = 32 // 4
@@ -37,7 +39,8 @@ class MarlinConfig(QuantizationConfig):
         # Min in_features dim
         self.min_k_threads = 128
 
-        # Max parallel problems to solve at once (improves large batch performance)
+        # Max parallel problems to solve at once (improves large batch
+        # performance)
         self.max_parallel = 16
 
         # Permutation length used by the marlin kernels.
@@ -109,22 +112,27 @@ class MarlinLinearMethod(LinearMethodBase):
         # Validate output_size_per_partition
         if output_size_per_partition % self.quant_config.min_n_threads != 0:
             raise ValueError(
-                f"Weight output_size_per_partition = {output_size_per_partition} is not divisible by min_n_threads = {self.quant_config.min_n_threads}."
-            )
+                "Weight output_size_per_partition = "
+                f"{output_size_per_partition} is not divisible by "
+                f"min_n_threads = {self.quant_config.min_n_threads}.")
         if output_size_per_partition % self.quant_config.pack_factor != 0:
             raise ValueError(
-                f"Weight output_size_per_partition = {output_size_per_partition} is not divisible by pack_factor = {self.quant_config.pack_factor}."
-            )
+                f"Weight output_size_per_partition = "
+                f"{output_size_per_partition} is not divisible by pack_factor "
+                f"= {self.quant_config.pack_factor}.")
 
         # Validate input_size_per_partition
         if input_size_per_partition % self.quant_config.min_k_threads != 0:
             raise ValueError(
-                f"Weight input_size_per_partition = {input_size_per_partition} is not divisible by min_k_threads = {self.quant_config.min_k_threads}."
-            )
-        if self.quant_config.group_size != -1 and input_size_per_partition % self.quant_config.group_size != 0:
+                f"Weight input_size_per_partition = {input_size_per_partition}"
+                " is not divisible by min_k_threads = "
+                f"{self.quant_config.min_k_threads}.")
+        if (self.quant_config.group_size != -1 and
+                input_size_per_partition % self.quant_config.group_size != 0):
             raise ValueError(
-                f"Weight input_size_per_partition = f{input_size_per_partition} is not divisible by group_size = {self.quant_config.group_size}."
-            )
+                f"Weight input_size_per_partition = {input_size_per_partition} "
+                "is not divisible by group_size = "
+                f"{self.quant_config.group_size}.")
 
         # Check that we have at least 4 tiles horizontally in the shard
         num_tiles_per_perm = self.quant_config.perm_len // (
@@ -156,7 +164,9 @@ class MarlinLinearMethod(LinearMethodBase):
         )
 
         # Determine if channelwise or not
-        input_groups = 1 if self.quant_config.group_size == -1 else input_size_per_partition // self.quant_config.group_size
+        input_groups = (1 if self.quant_config.group_size == -1 else
+                        input_size_per_partition //
+                        self.quant_config.group_size)
 
         scales = Parameter(
             torch.empty(
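
For the input_groups expression re-wrapped in the hunk above: group_size == -1 means channelwise quantization, i.e. a single scale row covering the whole input dimension, while a positive group_size yields one scale row per group of input features. A quick check with made-up sizes:

    def marlin_input_groups(input_size_per_partition: int, group_size: int) -> int:
        # Channelwise (-1) collapses to one group; otherwise groups tile the
        # input dimension (which the earlier checks require to divide evenly).
        return 1 if group_size == -1 else input_size_per_partition // group_size

    print(marlin_input_groups(4096, -1))   # 1  (channelwise scales)
    print(marlin_input_groups(4096, 128))  # 32 (one scale row per 128 inputs)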

+ 2 - 1
aphrodite/modeling/layers/quantization/quip.py

@@ -89,7 +89,8 @@ class QuipLinearMethod(LinearMethodBase):
         params_dtype: torch.dtype,
     ) -> Dict[str, Any]:
         output_size_per_partition = sum(output_partition_sizes)
-        if input_size != input_size_per_partition or output_size != output_size_per_partition:
+        if (input_size != input_size_per_partition
+                or output_size != output_size_per_partition):
             raise ValueError(
                 "Currently Quip doesn't support tensor parallel yet")
 

+ 2 - 3
aphrodite/modeling/layers/quantization/quip_utils.py

@@ -4,11 +4,10 @@ from pathlib import Path
 import scipy
 import torch
 from safetensors.torch import load_file
+from contextlib import suppress
 
-try:
+with suppress(ImportError):
     import aphrodite._hadamard_C as hadamard_C
-except ImportError:
-    pass
 
 HADA_TENSORS = load_file(
     Path(__file__).resolve().parent / "hadamard.safetensors")
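
The quip_utils change above replaces a try/except-ImportError-pass around the optional hadamard extension with contextlib.suppress, which does the same thing in one statement and avoids the empty except body. A minimal equivalence sketch using stand-in module names (purely hypothetical):

    from contextlib import suppress

    # Old style: swallow the ImportError by hand.
    try:
        import some_optional_extension_old  # hypothetical optional dependency
    except ImportError:
        pass

    # New style: identical behaviour, no empty except clause.
    with suppress(ImportError):
        import some_optional_extension_new  # hypothetical optional dependency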

+ 2 - 1
aphrodite/modeling/layers/quantization/squeezellm.py

@@ -6,7 +6,8 @@ from torch.nn.parameter import Parameter
 from aphrodite._C import ops
 from aphrodite.modeling.layers.linear import (LinearMethodBase,
                                               set_weight_attrs)
-from aphrodite.modeling.layers.quantization.base_config import QuantizationConfig
+from aphrodite.modeling.layers.quantization.base_config import (
+    QuantizationConfig)
 from aphrodite.common.utils import is_hip
 
 

+ 2 - 2
aphrodite/modeling/layers/rotary_embedding.py

@@ -32,14 +32,14 @@ from aphrodite._C import ops
 
 
 def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
-    """PyTorch-native implemenation."""
+    """PyTorch-native implementation."""
     x1 = x[..., :x.shape[-1] // 2]
     x2 = x[..., x.shape[-1] // 2:]
     return torch.cat((-x2, x1), dim=-1)
 
 
 def _rotate_gptj(x: torch.Tensor) -> torch.Tensor:
-    """PyTorch-native implemenation."""
+    """PyTorch-native implementation."""
     x1 = x[..., ::2]
     x2 = x[..., 1::2]
     x = torch.stack((-x2, x1), dim=-1)
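
The two docstring fixes above belong to the reference implementations of the rotation helpers, which differ only in how they split the head dimension: the NeoX variant rotates contiguous halves, while the GPT-J variant rotates interleaved even/odd pairs. A small demonstration on a four-element vector:

    import torch

    x = torch.arange(1.0, 5.0)  # tensor([1., 2., 3., 4.])

    # NeoX-style: halves x1=[1, 2], x2=[3, 4]  ->  [-3., -4., 1., 2.]
    x1, x2 = x[:2], x[2:]
    print(torch.cat((-x2, x1)))

    # GPT-J-style: interleaved x1=[1, 3], x2=[2, 4]  ->  [-2., 1., -4., 3.]
    x1, x2 = x[::2], x[1::2]
    print(torch.stack((-x2, x1), dim=-1).flatten())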

+ 3 - 3
aphrodite/modeling/layers/sampler.py

@@ -276,8 +276,8 @@ def _apply_alphabet_soup(
     probs_sum = probs_sort.cumsum(dim=-1).sub_(probs_sort)
     min_p_thresholds = probs_sort[:, 0] * m
     top_a_thresholds = torch.pow(probs_sort[:, 0], 2) * a
-    treshold = torch.maximum(min_p_thresholds, top_a_thresholds)
-    mask = (probs_sort < treshold.unsqueeze(1)
+    threshold = torch.maximum(min_p_thresholds, top_a_thresholds)
+    mask = (probs_sort < threshold.unsqueeze(1)
             )  # Cull logits below the top-a threshold
     mask.logical_or_(
         probs_sum >
@@ -887,6 +887,6 @@ def _mirostat(logits: torch.Tensor, sampling_tensors: SamplingTensors,
     mus = sampling_tensors.miro_mus
 
     logits[idx] = _apply_mirostat_v2(logits[idx], taus, etas,
-                                     mus)  # mus is an inout param, :vomit:
+                                     mus)  # mus is an i/o param, :vomit:
     _miro_store_args(seqids, mus, output_metadata)
     return logits
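
On the corrected threshold variable above: the sampler culls tokens whose sorted probability falls below the larger of the min-p cutoff (m times the top probability) and the top-a cutoff (a times the top probability squared); the surrounding top-p mask is OR-ed in separately. A worked toy example (probabilities and parameters invented):

    probs = [0.50, 0.20, 0.08, 0.06, 0.04]   # sorted token probabilities
    m, a = 0.1, 0.3                          # hypothetical min_p / top_a values

    p_max = probs[0]
    threshold = max(m * p_max, a * p_max ** 2)   # max(0.05, 0.075) = 0.075

    kept = [p for p in probs if p >= threshold]
    print(threshold, kept)                       # 0.075 [0.5, 0.2, 0.08]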

+ 6 - 5
aphrodite/modeling/layers/triton_kernel/fused_moe.py

@@ -22,9 +22,9 @@ def fused_moe_kernel(
     K,
     EM,
     num_valid_tokens,
-    # The stride variables represent how much to increase the ptr by when moving by 1
-    # element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr`
-    # by to get the element one row down (A has M rows).
+    # The stride variables represent how much to increase the ptr by when moving
+    # by 1 element in a particular dimension. E.g. `stride_am` is how much to
+    # increase `a_ptr` by to get the element one row down (A has M rows).
     stride_am,
     stride_ak,
     stride_be,
@@ -202,7 +202,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
     assert topk_weights.stride(1) == 1
     assert sorted_token_ids.stride(0) == 1
 
-    # pylint: disable=unnecessary-lambda-assignment
+    # ruff: noqa: E731
     grid = lambda META: (triton.cdiv(sorted_token_ids.shape[0], META[
         'BLOCK_SIZE_M']) * triton.cdiv(B.shape[1], META['BLOCK_SIZE_N']), )
 
@@ -249,7 +249,8 @@ def fused_moe(
     - hidden_states (torch.Tensor): The input tensor to the MoE layer.
     - w1 (torch.Tensor): The first set of expert weights.
     - w2 (torch.Tensor): The second set of expert weights.
-    - gating_output (torch.Tensor): The output of the gating operation (before softmax).
+    - gating_output (torch.Tensor): The output of the gating operation (before
+        softmax).
     - topk (int): The number of top-k experts to select.
     - renormalize (bool): If True, renormalize the top-k weights to sum to 1.
     - inplace (bool): If True, perform the operation in-place. Defaults to
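
The stride comment re-wrapped in the first hunk above describes ordinary row-major pointer arithmetic: for a contiguous (M, K) matrix, stride_am is K (elements skipped to move down one row) and stride_ak is 1, so element (m, k) lives at base + m * stride_am + k * stride_ak. A quick PyTorch check:

    import torch

    A = torch.arange(12.0).reshape(3, 4)   # M=3, K=4, contiguous
    stride_am, stride_ak = A.stride()      # (4, 1), measured in elements

    m, k = 2, 3
    flat_index = m * stride_am + k * stride_ak
    print(A.flatten()[flat_index].item(), A[m, k].item())   # 11.0 11.0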

+ 1 - 1
aphrodite/modeling/layers/triton_kernel/prefix_prefill.py

@@ -537,7 +537,7 @@ if triton.__version__ >= "2.1.0":
         alibi_start_q = tl.arange(
             0, BLOCK_M) + block_start_loc + cur_batch_ctx_len
         alibi_start_k = cur_batch_ctx_len
-        # # init debuger
+        # # init debugger
         # offset_db_q = tl.arange(0, BLOCK_M) + block_start_loc
         # offset_db_k = tl.arange(0, BLOCK_N)
         # calc q[BLOCK_M, BLOCK_MODEL] mul k[prefix_len: , BLOCK_DMODEL]

+ 2 - 2
aphrodite/modeling/megatron/custom_all_reduce.py

@@ -32,14 +32,14 @@ def init_custom_ar() -> None:
     if world_size not in _SUPPORTED_WORLD_SIZES:
         logger.warning(
             "Custom allreduce is disabled due to an unsupported world size: "
-            "%d. Supported world sizes: %s. To slience this warning, specify "
+            "%d. Supported world sizes: %s. To silence this warning, specify "
             "disable_custom_all_reduce=True explicitly.", world_size,
             str(_SUPPORTED_WORLD_SIZES))
         return
     if not _can_p2p(rank, world_size):
         logger.warning(
             "Custom allreduce is disabled because your platform lacks GPU P2P"
-            " capability. To slience this warning, specify "
+            " capability. To silence this warning, specify "
             "disable_custom_all_reduce=True explicitly.")
         return
     _CA_HANDLE = CustomAllreduce(rank, world_size)

+ 1 - 1
aphrodite/modeling/megatron/parallel_state.py

@@ -190,7 +190,7 @@ def get_pipeline_model_parallel_next_rank():
 
 
 def get_pipeline_model_parallel_prev_rank():
-    """Return the global rank that preceeds the caller in the pipeline"""
+    """Return the global rank that precedes the caller in the pipeline"""
     assert _PIPELINE_GLOBAL_RANKS is not None, (
         "Pipeline parallel group is not initialized")
     rank_in_pipeline = get_pipeline_model_parallel_rank()

+ 1 - 1
aphrodite/modeling/models/__init__.py

@@ -4,7 +4,6 @@ from loguru import logger
 
 import torch.nn as nn
 
-from aphrodite.common.logger import setup_logger
 from aphrodite.common.utils import is_hip
 
 # Architecture -> (module, class).
@@ -42,6 +41,7 @@ _MODELS = {
     "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
     "RWForCausalLM": ("falcon", "FalconForCausalLM"),
     "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
+    "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
 }
 
 # Models not supported by ROCm.

+ 82 - 54
aphrodite/modeling/models/baichuan.py

@@ -18,6 +18,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only BaiChuan model compatible with HuggingFace weights."""
+
 import math
 from typing import List, Optional, Tuple
 
@@ -28,20 +29,28 @@ from aphrodite.modeling.metadata import InputMetadata
 from aphrodite.modeling.layers.activation import SiluAndMul
 from aphrodite.modeling.layers.attention import PagedAttention
 from aphrodite.modeling.layers.layernorm import RMSNorm
-from aphrodite.modeling.layers.linear import (LinearMethodBase,
-                                              MergedColumnParallelLinear,
-                                              QKVParallelLinear,
-                                              RowParallelLinear,
-                                              ColumnParallelLinear)
+from aphrodite.modeling.layers.linear import (
+    LinearMethodBase,
+    MergedColumnParallelLinear,
+    QKVParallelLinear,
+    RowParallelLinear,
+    ColumnParallelLinear,
+)
 from aphrodite.modeling.layers.rotary_embedding import get_rope
 from aphrodite.modeling.layers.sampler import Sampler
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
-    VocabParallelEmbedding, ParallelLMHead)
+    VocabParallelEmbedding,
+    ParallelLMHead,
+)
 from aphrodite.modeling.megatron.parallel_state import (
-    get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+)
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
-from aphrodite.modeling.hf_downloader import (default_weight_loader,
-                                              hf_model_weights_iterator)
+from aphrodite.modeling.hf_downloader import (
+    default_weight_loader,
+    hf_model_weights_iterator,
+)
 from aphrodite.common.sequence import SamplerOutput
 from aphrodite.transformers_utils.configs.baichuan import BaiChuanConfig
 
@@ -83,27 +92,35 @@ class BaiChuanMLP(nn.Module):
         linear_method: Optional[LinearMethodBase] = None,
     ):
         super().__init__()
-        if linear_method is not None and not linear_method.quant_config.merge_weight(
-        ):
+        if (linear_method is not None
+                and not linear_method.quant_config.merge_weight()):
             self.merge_weight = False
-            self.gate_proj = ColumnParallelLinear(hidden_size,
-                                                  intermediate_size,
-                                                  bias=False,
-                                                  linear_method=linear_method)
-            self.up_proj = ColumnParallelLinear(hidden_size,
-                                                intermediate_size,
-                                                bias=False,
-                                                linear_method=linear_method)
+            self.gate_proj = ColumnParallelLinear(
+                hidden_size,
+                intermediate_size,
+                bias=False,
+                linear_method=linear_method,
+            )
+            self.up_proj = ColumnParallelLinear(
+                hidden_size,
+                intermediate_size,
+                bias=False,
+                linear_method=linear_method,
+            )
         else:
             self.merge_weight = True
             self.gate_up_proj = MergedColumnParallelLinear(
-                hidden_size, [intermediate_size] * 2,
+                hidden_size,
+                [intermediate_size] * 2,
                 bias=False,
-                linear_method=linear_method)
-        self.down_proj = RowParallelLinear(intermediate_size,
-                                           hidden_size,
-                                           bias=False,
-                                           linear_method=linear_method)
+                linear_method=linear_method,
+            )
+        self.down_proj = RowParallelLinear(
+            intermediate_size,
+            hidden_size,
+            bias=False,
+            linear_method=linear_method,
+        )
         if hidden_act != "silu":
             raise ValueError(f"Unsupported activation: {hidden_act}. "
                              "Only silu is supported for now.")
@@ -135,8 +152,8 @@ class BaiChuanAttention(nn.Module):
     ):
         super().__init__()
         self.hidden_size = hidden_size
-        tensor_model_parallel_world_size = get_tensor_model_parallel_world_size(
-        )
+        tensor_model_parallel_world_size = (
+            get_tensor_model_parallel_world_size())
         self.total_num_heads = num_heads
         assert self.total_num_heads % tensor_model_parallel_world_size == 0
         self.num_heads = (self.total_num_heads //
@@ -170,13 +187,16 @@ class BaiChuanAttention(nn.Module):
             alibi_slopes = alibi_slopes[head_start:head_end].tolist()
 
             scaling = self.head_dim**-0.5
-            self.attn = PagedAttention(self.num_heads,
-                                       self.head_dim,
-                                       scaling,
-                                       alibi_slopes=alibi_slopes)
+            self.attn = PagedAttention(
+                self.num_heads,
+                self.head_dim,
+                scaling,
+                alibi_slopes=alibi_slopes,
+            )
         else:
-            is_neox_style = True if linear_method is None or linear_method.quant_config.rope_style(
-            ) is None else linear_method.quant_config.rope_style()
+            is_neox_style = (True if linear_method is None
+                             or linear_method.quant_config.rope_style() is None
+                             else linear_method.quant_config.rope_style())
             self.rotary_emb = get_rope(
                 self.head_dim,
                 rotary_dim=self.head_dim,
@@ -207,10 +227,12 @@ class BaiChuanAttention(nn.Module):
 
 class BaiChuanDecoderLayer(nn.Module):
 
-    def __init__(self,
-                 config: BaiChuanConfig,
-                 position_embedding: str,
-                 linear_method: Optional[LinearMethodBase] = None):
+    def __init__(
+        self,
+        config: BaiChuanConfig,
+        position_embedding: str,
+        linear_method: Optional[LinearMethodBase] = None,
+    ):
         super().__init__()
         self.hidden_size = config.hidden_size
         rope_theta = getattr(config, "rope_theta", 10000)
@@ -266,10 +288,12 @@ class BaiChuanDecoderLayer(nn.Module):
 
 class BaiChuanModel(nn.Module):
 
-    def __init__(self,
-                 config: BaiChuanConfig,
-                 position_embedding: str,
-                 linear_method: Optional[LinearMethodBase] = None):
+    def __init__(
+        self,
+        config: BaiChuanConfig,
+        position_embedding: str,
+        linear_method: Optional[LinearMethodBase] = None,
+    ):
         super().__init__()
         self.config = config
         self.padding_idx = config.pad_token_id
@@ -308,10 +332,12 @@ class BaiChuanModel(nn.Module):
 
 class BaiChuanBaseForCausalLM(nn.Module):
 
-    def __init__(self,
-                 config,
-                 position_embedding: str,
-                 linear_method: Optional[LinearMethodBase] = None):
+    def __init__(
+        self,
+        config,
+        position_embedding: str,
+        linear_method: Optional[LinearMethodBase] = None,
+    ):
         super().__init__()
         self.config = config
         self.linear_method = linear_method
@@ -341,18 +367,20 @@ class BaiChuanBaseForCausalLM(nn.Module):
                                    sampling_metadata)
         return next_tokens
 
-    def load_weights(self,
-                     model_name_or_path: str,
-                     cache_dir: Optional[str] = None,
-                     load_format: str = "auto",
-                     revision: Optional[str] = None):
+    def load_weights(
+        self,
+        model_name_or_path: str,
+        cache_dir: Optional[str] = None,
+        load_format: str = "auto",
+        revision: Optional[str] = None,
+    ):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("gate_up_proj", "gate_proj", 0),
             ("gate_up_proj", "up_proj", 1),
         ]
-        if self.linear_method is not None and not self.linear_method.quant_config.merge_weight(
-        ):
+        if (self.linear_method is not None
+                and not self.linear_method.quant_config.merge_weight()):
             stacked_params_mapping = []
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in hf_model_weights_iterator(
@@ -361,7 +389,7 @@ class BaiChuanBaseForCausalLM(nn.Module):
             if "rotary_emb.inv_freq" in name:
                 continue
             if name == "lm_head.weight":
-                # Unlike Baichuan, Baichuan2 normalizes the head weights. Refer to:
+                # Unlike Baichuan, Baichuan2 normalizes the head weights. Ref.:
                 # https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat/blob/84603cde5ebffb6084e476cfaeceaf0b8b91fe54/modeling_baichuan.py#L508
                 # Distinguish between Baichuan and Baichuan2 by checking the
                 # vocab size.
@@ -370,7 +398,7 @@ class BaiChuanBaseForCausalLM(nn.Module):
                     loaded_weight = torch.nn.functional.normalize(
                         loaded_weight)
 
-            for (param_name, weight_name, shard_id) in stacked_params_mapping:
+            for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
                     continue
                 name = name.replace(weight_name, param_name)
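
The stacked_params_mapping idiom reformatted above (and repeated in the other model files below) is how a fused parameter such as gate_up_proj is filled from checkpoints that store gate_proj and up_proj separately: each matching checkpoint name is rewritten to the fused parameter's name and the shard id tells the weight loader which slice to write; when the quant method keeps weights unmerged, the mapping is emptied and names pass through untouched. A simplified sketch of the renaming step only (tensor loading elided):

    stacked_params_mapping = [
        # (param_name, shard_name, shard_id)
        ("gate_up_proj", "gate_proj", 0),
        ("gate_up_proj", "up_proj", 1),
    ]

    def resolve(checkpoint_name: str):
        """Map a checkpoint weight name to (model param name, shard id)."""
        for param_name, weight_name, shard_id in stacked_params_mapping:
            if weight_name in checkpoint_name:
                return checkpoint_name.replace(weight_name, param_name), shard_id
        return checkpoint_name, None   # falls back to default_weight_loader

    print(resolve("model.layers.0.mlp.gate_proj.weight"))
    # ('model.layers.0.mlp.gate_up_proj.weight', 0)
    print(resolve("model.layers.0.mlp.down_proj.weight"))
    # ('model.layers.0.mlp.down_proj.weight', None)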

+ 78 - 53
aphrodite/modeling/models/deepseek.py

@@ -22,33 +22,42 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only Deepseek model."""
+
 from typing import Any, Dict, List, Optional, Tuple
 
 import torch
 from torch import nn
 from transformers import PretrainedConfig
 
-from aphrodite.modeling.megatron import InputMetadata
+from aphrodite.modeling.metadata import InputMetadata
 from aphrodite.modeling.layers.activation import SiluAndMul
 from aphrodite.modeling.layers.attention import PagedAttention
 from aphrodite.modeling.layers.triton_kernel.fused_moe import fused_moe
 from aphrodite.modeling.layers.layernorm import RMSNorm
-from aphrodite.modeling.layers.linear import (LinearMethodBase,
-                                              MergedColumnParallelLinear,
-                                              ReplicatedLinear,
-                                              QKVParallelLinear,
-                                              RowParallelLinear)
+from aphrodite.modeling.layers.linear import (
+    LinearMethodBase,
+    MergedColumnParallelLinear,
+    ReplicatedLinear,
+    QKVParallelLinear,
+    RowParallelLinear,
+)
 from aphrodite.modeling.layers.rotary_embedding import get_rope
 from aphrodite.modeling.layers.sampler import Sampler
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
-    VocabParallelEmbedding, ParallelLMHead)
+    VocabParallelEmbedding,
+    ParallelLMHead,
+)
 from aphrodite.modeling.megatron.communication_op import (
-    tensor_model_parallel_all_reduce)
+    tensor_model_parallel_all_reduce, )
 from aphrodite.modeling.megatron.parallel_state import (
-    get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+)
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
-from aphrodite.modeling.hf_downloader import (default_weight_loader,
-                                              hf_model_weights_iterator)
+from aphrodite.modeling.hf_downloader import (
+    default_weight_loader,
+    hf_model_weights_iterator,
+)
 from aphrodite.common.sequence import SamplerOutput
 
 KVCache = Tuple[torch.Tensor, torch.Tensor]
@@ -66,14 +75,18 @@ class DeepseekMLP(nn.Module):
     ) -> None:
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(
-            hidden_size, [intermediate_size] * 2,
+            hidden_size,
+            [intermediate_size] * 2,
+            bias=False,
+            linear_method=linear_method,
+        )
+        self.down_proj = RowParallelLinear(
+            intermediate_size,
+            hidden_size,
             bias=False,
-            linear_method=linear_method)
-        self.down_proj = RowParallelLinear(intermediate_size,
-                                           hidden_size,
-                                           bias=False,
-                                           linear_method=linear_method,
-                                           reduce_results=reduce_results)
+            linear_method=linear_method,
+            reduce_results=reduce_results,
+        )
         if hidden_act != "silu":
             raise ValueError(f"Unsupported activation: {hidden_act}. "
                              "Only silu is supported for now.")
@@ -105,22 +118,26 @@ class DeepseekMoE(nn.Module):
                 f"the number of experts {self.n_routed_experts}.")
 
         self.experts = nn.ModuleList([
-            DeepseekMLP(hidden_size=config.hidden_size,
-                        intermediate_size=config.moe_intermediate_size,
-                        hidden_act=config.hidden_act,
-                        linear_method=linear_method,
-                        reduce_results=False)
-            for idx in range(self.n_routed_experts)
+            DeepseekMLP(
+                hidden_size=config.hidden_size,
+                intermediate_size=config.moe_intermediate_size,
+                hidden_act=config.hidden_act,
+                linear_method=linear_method,
+                reduce_results=False,
+            ) for idx in range(self.n_routed_experts)
         ])
         self.pack_params()
 
-        self.gate = ReplicatedLinear(config.hidden_size,
-                                     self.n_routed_experts,
-                                     bias=False,
-                                     linear_method=None)
+        self.gate = ReplicatedLinear(
+            config.hidden_size,
+            self.n_routed_experts,
+            bias=False,
+            linear_method=None,
+        )
 
         if config.n_shared_experts is not None:
-            intermediate_size = config.moe_intermediate_size * config.n_shared_experts
+            intermediate_size = (config.moe_intermediate_size *
+                                 config.n_shared_experts)
             self.shared_experts = DeepseekMLP(
                 hidden_size=config.hidden_size,
                 intermediate_size=intermediate_size,
@@ -155,13 +172,15 @@ class DeepseekMoE(nn.Module):
             shared_output = self.shared_experts(hidden_states)
         # router_logits: (batch * sequence_length, n_experts)
         router_logits, _ = self.gate(hidden_states)
-        final_hidden_states = fused_moe(hidden_states,
-                                        self.w1,
-                                        self.w2,
-                                        router_logits,
-                                        self.top_k,
-                                        renormalize=self.config.norm_topk_prob,
-                                        inplace=True)
+        final_hidden_states = fused_moe(
+            hidden_states,
+            self.w1,
+            self.w2,
+            router_logits,
+            self.top_k,
+            renormalize=self.config.norm_topk_prob,
+            inplace=True,
+        )
 
         if self.config.n_shared_experts is not None:
             final_hidden_states = final_hidden_states + shared_output
@@ -230,10 +249,12 @@ class DeepseekAttention(nn.Module):
             base=rope_theta,
             rope_scaling=rope_scaling,
         )
-        self.attn = PagedAttention(self.num_heads,
-                                   self.head_dim,
-                                   self.scaling,
-                                   num_kv_heads=self.num_kv_heads)
+        self.attn = PagedAttention(
+            self.num_heads,
+            self.head_dim,
+            self.scaling,
+            num_kv_heads=self.num_kv_heads,
+        )
 
     def forward(
         self,
@@ -274,8 +295,9 @@ class DeepseekDecoderLayer(nn.Module):
             max_position_embeddings=max_position_embeddings,
             linear_method=linear_method,
         )
-        if (config.n_routed_experts is not None and  \
-            layer_idx >= config.first_k_dense_replace and layer_idx % config.moe_layer_freq == 0):
+        if (config.n_routed_experts is not None
+                and layer_idx >= config.first_k_dense_replace
+                and layer_idx % config.moe_layer_freq == 0):
             self.mlp = DeepseekMoE(config=config, linear_method=linear_method)
         else:
             self.mlp = DeepseekMLP(
@@ -393,11 +415,13 @@ class DeepseekForCausalLM(nn.Module):
                                    sampling_metadata)
         return next_tokens
 
-    def load_weights(self,
-                     model_name_or_path: str,
-                     cache_dir: Optional[str] = None,
-                     load_format: str = "auto",
-                     revision: Optional[str] = None):
+    def load_weights(
+        self,
+        model_name_or_path: str,
+        cache_dir: Optional[str] = None,
+        load_format: str = "auto",
+        revision: Optional[str] = None,
+    ):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("qkv_proj", "q_proj", "q"),
@@ -413,10 +437,11 @@ class DeepseekForCausalLM(nn.Module):
                 cache_dir,
                 load_format,
                 revision,
-                fall_back_to_pt=False):
+                fall_back_to_pt=False,
+        ):
             if "rotary_emb.inv_freq" in name:
                 continue
-            for (param_name, weight_name, shard_id) in stacked_params_mapping:
+            for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
                     continue
                 name = name.replace(weight_name, param_name)
@@ -424,8 +449,8 @@ class DeepseekForCausalLM(nn.Module):
                 if name.endswith(".bias") and name not in params_dict:
                     continue
                 # Skip experts that are not assigned to this worker.
-                if (("mlp.experts." in name or "mlp.shared_experts." in name)
-                        and name not in params_dict):
+                if ("mlp.experts." in name or "mlp.shared_experts."
+                        in name) and name not in params_dict:
                     continue
                 param = params_dict[name]
                 weight_loader = param.weight_loader
@@ -436,8 +461,8 @@ class DeepseekForCausalLM(nn.Module):
                 if name.endswith(".bias") and name not in params_dict:
                     continue
                 # Skip experts that are not assigned to this worker.
-                if (("mlp.experts." in name or "mlp.shared_experts." in name)
-                        and name not in params_dict):
+                if ("mlp.experts." in name or "mlp.shared_experts."
+                        in name) and name not in params_dict:
                     continue
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader",
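
The condition straightened out in the DeepseekDecoderLayer hunk above selects which layers get the MoE block: routed experts must be configured, the layer index must be at least first_k_dense_replace, and the index must be a multiple of moe_layer_freq; every other layer keeps the dense DeepseekMLP. A toy enumeration with invented config values:

    n_routed_experts = 64        # hypothetical
    first_k_dense_replace = 1    # hypothetical
    moe_layer_freq = 2           # hypothetical

    def is_moe_layer(layer_idx: int) -> bool:
        return (n_routed_experts is not None
                and layer_idx >= first_k_dense_replace
                and layer_idx % moe_layer_freq == 0)

    print([idx for idx in range(8) if is_moe_layer(idx)])   # [2, 4, 6]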

+ 86 - 59
aphrodite/modeling/models/gemma.py

@@ -15,6 +15,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only Gemma model compatible with HuggingFace weights."""
+
 from typing import List, Optional, Tuple
 
 import torch
@@ -25,20 +26,26 @@ from aphrodite.modeling.metadata import InputMetadata
 from aphrodite.modeling.layers.activation import GeluAndMul
 from aphrodite.modeling.layers.attention import PagedAttention
 from aphrodite.modeling.layers.layernorm import RMSNorm
-from aphrodite.modeling.layers.linear import (LinearMethodBase,
-                                              MergedColumnParallelLinear,
-                                              QKVParallelLinear,
-                                              RowParallelLinear,
-                                              ColumnParallelLinear)
+from aphrodite.modeling.layers.linear import (
+    LinearMethodBase,
+    MergedColumnParallelLinear,
+    QKVParallelLinear,
+    RowParallelLinear,
+    ColumnParallelLinear,
+)
 from aphrodite.modeling.layers.rotary_embedding import get_rope
 from aphrodite.modeling.layers.sampler import Sampler
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
-    VocabParallelEmbedding, ParallelLMHead)
+    VocabParallelEmbedding,
+    ParallelLMHead,
+)
 from aphrodite.modeling.megatron.parallel_state import (
-    get_tensor_model_parallel_world_size)
+    get_tensor_model_parallel_world_size, )
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
-from aphrodite.modeling.hf_downloader import (default_weight_loader,
-                                              hf_model_weights_iterator)
+from aphrodite.modeling.hf_downloader import (
+    default_weight_loader,
+    hf_model_weights_iterator,
+)
 from aphrodite.common.sequence import SamplerOutput
 
 KVCache = Tuple[torch.Tensor, torch.Tensor]
@@ -53,27 +60,35 @@ class GemmaMLP(nn.Module):
         linear_method: Optional[LinearMethodBase] = None,
     ) -> None:
         super().__init__()
-        if linear_method is not None and not linear_method.quant_config.merge_weight(
-        ):
+        if (linear_method is not None
+                and not linear_method.quant_config.merge_weight()):
             self.merge_weight = False
-            self.gate_proj = ColumnParallelLinear(hidden_size,
-                                                  intermediate_size,
-                                                  bias=False,
-                                                  linear_method=linear_method)
-            self.up_proj = ColumnParallelLinear(hidden_size,
-                                                intermediate_size,
-                                                bias=False,
-                                                linear_method=linear_method)
+            self.gate_proj = ColumnParallelLinear(
+                hidden_size,
+                intermediate_size,
+                bias=False,
+                linear_method=linear_method,
+            )
+            self.up_proj = ColumnParallelLinear(
+                hidden_size,
+                intermediate_size,
+                bias=False,
+                linear_method=linear_method,
+            )
         else:
             self.merge_weight = True
             self.gate_up_proj = MergedColumnParallelLinear(
-                hidden_size, [intermediate_size] * 2,
+                hidden_size,
+                [intermediate_size] * 2,
                 bias=False,
-                linear_method=linear_method)
-        self.down_proj = RowParallelLinear(intermediate_size,
-                                           hidden_size,
-                                           bias=False,
-                                           linear_method=linear_method)
+                linear_method=linear_method,
+            )
+        self.down_proj = RowParallelLinear(
+            intermediate_size,
+            hidden_size,
+            bias=False,
+            linear_method=linear_method,
+        )
         self.act_fn = GeluAndMul()
 
     def forward(self, x):
@@ -90,14 +105,16 @@ class GemmaMLP(nn.Module):
 
 class GemmaAttention(nn.Module):
 
-    def __init__(self,
-                 hidden_size: int,
-                 num_heads: int,
-                 num_kv_heads: int,
-                 head_dim: int,
-                 max_position_embeddings: int = 8192,
-                 rope_theta: float = 10000,
-                 linear_method: Optional[LinearMethodBase] = None) -> None:
+    def __init__(
+        self,
+        hidden_size: int,
+        num_heads: int,
+        num_kv_heads: int,
+        head_dim: int,
+        max_position_embeddings: int = 8192,
+        rope_theta: float = 10000,
+        linear_method: Optional[LinearMethodBase] = None,
+    ) -> None:
         super().__init__()
         self.hidden_size = hidden_size
         tp_size = get_tensor_model_parallel_world_size()
@@ -120,21 +137,27 @@ class GemmaAttention(nn.Module):
         self.scaling = self.head_dim**-0.5
         self.rope_theta = rope_theta
 
-        if linear_method is not None and not linear_method.quant_config.merge_weight(
-        ):
+        if (linear_method is not None
+                and not linear_method.quant_config.merge_weight()):
             self.merge_weight = False
-            self.q_proj = ColumnParallelLinear(hidden_size,
-                                               self.q_size,
-                                               bias=False,
-                                               linear_method=linear_method)
-            self.k_proj = ColumnParallelLinear(hidden_size,
-                                               self.kv_size,
-                                               bias=False,
-                                               linear_method=linear_method)
-            self.v_proj = ColumnParallelLinear(hidden_size,
-                                               self.kv_size,
-                                               bias=False,
-                                               linear_method=linear_method)
+            self.q_proj = ColumnParallelLinear(
+                hidden_size,
+                self.q_size,
+                bias=False,
+                linear_method=linear_method,
+            )
+            self.k_proj = ColumnParallelLinear(
+                hidden_size,
+                self.kv_size,
+                bias=False,
+                linear_method=linear_method,
+            )
+            self.v_proj = ColumnParallelLinear(
+                hidden_size,
+                self.kv_size,
+                bias=False,
+                linear_method=linear_method,
+            )
         else:
             self.merge_weight = True
             self.qkv_proj = QKVParallelLinear(
@@ -158,10 +181,12 @@ class GemmaAttention(nn.Module):
             base=self.rope_theta,
             is_neox_style=True,
         )
-        self.attn = PagedAttention(self.num_heads,
-                                   self.head_dim,
-                                   self.scaling,
-                                   num_kv_heads=self.num_kv_heads)
+        self.attn = PagedAttention(
+            self.num_heads,
+            self.head_dim,
+            self.scaling,
+            num_kv_heads=self.num_kv_heads,
+        )
 
     def forward(
         self,
@@ -323,11 +348,13 @@ class GemmaForCausalLM(nn.Module):
                                    sampling_metadata)
         return next_tokens
 
-    def load_weights(self,
-                     model_name_or_path: str,
-                     cache_dir: Optional[str] = None,
-                     load_format: str = "auto",
-                     revision: Optional[str] = None):
+    def load_weights(
+        self,
+        model_name_or_path: str,
+        cache_dir: Optional[str] = None,
+        load_format: str = "auto",
+        revision: Optional[str] = None,
+    ):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("qkv_proj", "q_proj", "q"),
@@ -336,8 +363,8 @@ class GemmaForCausalLM(nn.Module):
             ("gate_up_proj", "gate_proj", 0),
             ("gate_up_proj", "up_proj", 1),
         ]
-        if self.linear_method is not None and not self.linear_method.quant_config.merge_weight(
-        ):
+        if (self.linear_method is not None
+                and not self.linear_method.quant_config.merge_weight()):
             stacked_params_mapping = []
         params_dict = dict(self.named_parameters())
         loaded_params = set()
@@ -355,7 +382,7 @@ class GemmaForCausalLM(nn.Module):
                     weight_loader = getattr(lm_head_param, "weight_loader",
                                             default_weight_loader)
                     weight_loader(lm_head_param, loaded_weight)
-            for (param_name, weight_name, shard_id) in stacked_params_mapping:
+            for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
                     continue
                 name = name.replace(weight_name, param_name)

+ 53 - 35
aphrodite/modeling/models/gpt_j.py

@@ -17,6 +17,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only GPT-J model compatible with HuggingFace weights."""
+
 from typing import List, Optional, Tuple
 
 import torch
@@ -26,19 +27,25 @@ from transformers import GPTJConfig
 from aphrodite.modeling.metadata import InputMetadata
 from aphrodite.modeling.layers.activation import get_act_fn
 from aphrodite.modeling.layers.attention import PagedAttention
-from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
-                                              LinearMethodBase,
-                                              QKVParallelLinear,
-                                              RowParallelLinear)
+from aphrodite.modeling.layers.linear import (
+    ColumnParallelLinear,
+    LinearMethodBase,
+    QKVParallelLinear,
+    RowParallelLinear,
+)
 from aphrodite.modeling.layers.rotary_embedding import get_rope
 from aphrodite.modeling.layers.sampler import Sampler
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
-    VocabParallelEmbedding, ParallelLMHead)
+    VocabParallelEmbedding,
+    ParallelLMHead,
+)
 from aphrodite.modeling.megatron.parallel_state import (
-    get_tensor_model_parallel_world_size)
+    get_tensor_model_parallel_world_size, )
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
-from aphrodite.modeling.hf_downloader import (default_weight_loader,
-                                              hf_model_weights_iterator)
+from aphrodite.modeling.hf_downloader import (
+    default_weight_loader,
+    hf_model_weights_iterator,
+)
 from aphrodite.common.sequence import SamplerOutput
 
 KVCache = Tuple[torch.Tensor, torch.Tensor]
@@ -56,21 +63,27 @@ class GPTJAttention(nn.Module):
         self.hidden_size = config.hidden_size
         self.head_size = self.hidden_size // self.total_num_heads
 
-        if linear_method is not None and not linear_method.quant_config.merge_weight(
-        ):
+        if (linear_method is not None
+                and not linear_method.quant_config.merge_weight()):
             self.merge_weight = False
-            self.q_proj = ColumnParallelLinear(config.hidden_size,
-                                               config.hidden_size,
-                                               bias=False,
-                                               linear_method=linear_method)
-            self.k_proj = ColumnParallelLinear(config.hidden_size,
-                                               config.hidden_size,
-                                               bias=False,
-                                               linear_method=linear_method)
-            self.v_proj = ColumnParallelLinear(config.hidden_size,
-                                               config.hidden_size,
-                                               bias=False,
-                                               linear_method=linear_method)
+            self.q_proj = ColumnParallelLinear(
+                config.hidden_size,
+                config.hidden_size,
+                bias=False,
+                linear_method=linear_method,
+            )
+            self.k_proj = ColumnParallelLinear(
+                config.hidden_size,
+                config.hidden_size,
+                bias=False,
+                linear_method=linear_method,
+            )
+            self.v_proj = ColumnParallelLinear(
+                config.hidden_size,
+                config.hidden_size,
+                bias=False,
+                linear_method=linear_method,
+            )
         else:
             self.merge_weight = True
             self.qkv_proj = QKVParallelLinear(
@@ -166,7 +179,8 @@ class GPTJBlock(nn.Module):
         linear_method: Optional[LinearMethodBase] = None,
     ):
         super().__init__()
-        inner_dim = 4 * config.n_embd if config.n_inner is None else config.n_inner
+        inner_dim = (4 * config.n_embd
+                     if config.n_inner is None else config.n_inner)
         self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
         self.attn = GPTJAttention(config, linear_method)
         self.mlp = GPTJMLP(inner_dim, config, linear_method)
@@ -240,10 +254,12 @@ class GPTJForCausalLM(nn.Module):
         self.linear_method = linear_method
         assert not config.tie_word_embeddings
         self.transformer = GPTJModel(config, linear_method)
-        self.lm_head = ParallelLMHead(config.vocab_size,
-                                      config.n_embd,
-                                      bias=True,
-                                      linear_method=linear_method)
+        self.lm_head = ParallelLMHead(
+            config.vocab_size,
+            config.n_embd,
+            bias=True,
+            linear_method=linear_method,
+        )
         self.sampler = Sampler(config.vocab_size)
 
     def forward(
@@ -266,11 +282,13 @@ class GPTJForCausalLM(nn.Module):
                                    sampling_metadata, self.lm_head.bias)
         return next_tokens
 
-    def load_weights(self,
-                     model_name_or_path: str,
-                     cache_dir: Optional[str] = None,
-                     load_format: str = "auto",
-                     revision: Optional[str] = None):
+    def load_weights(
+        self,
+        model_name_or_path: str,
+        cache_dir: Optional[str] = None,
+        load_format: str = "auto",
+        revision: Optional[str] = None,
+    ):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("qkv_proj", "q_proj", "q"),
@@ -279,8 +297,8 @@ class GPTJForCausalLM(nn.Module):
             ("gate_up_proj", "gate_proj", 0),
             ("gate_up_proj", "up_proj", 1),
         ]
-        if self.linear_method is not None and not self.linear_method.quant_config.merge_weight(
-        ):
+        if (self.linear_method is not None
+                and not self.linear_method.quant_config.merge_weight()):
             stacked_params_mapping = []
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in hf_model_weights_iterator(
@@ -288,7 +306,7 @@ class GPTJForCausalLM(nn.Module):
                 self.config):
             if "attn.bias" in name or "attn.masked_bias" in name:
                 continue
-            for (param_name, weight_name, shard_id) in stacked_params_mapping:
+            for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
                     continue
                 name = name.replace(weight_name, param_name)

+ 25 - 15
aphrodite/modeling/models/gpt_neox.py

@@ -17,6 +17,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only GPT-NeoX model compatible with HuggingFace weights."""
+
 from typing import List, Optional, Tuple
 
 import torch
@@ -26,19 +27,25 @@ from transformers import GPTNeoXConfig
 from aphrodite.modeling.metadata import InputMetadata
 from aphrodite.modeling.layers.activation import get_act_fn
 from aphrodite.modeling.layers.attention import PagedAttention
-from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
-                                              LinearMethodBase,
-                                              QKVParallelLinear,
-                                              RowParallelLinear)
+from aphrodite.modeling.layers.linear import (
+    ColumnParallelLinear,
+    LinearMethodBase,
+    QKVParallelLinear,
+    RowParallelLinear,
+)
 from aphrodite.modeling.layers.rotary_embedding import get_rope
 from aphrodite.modeling.layers.sampler import Sampler
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
-    VocabParallelEmbedding, ParallelLMHead)
+    VocabParallelEmbedding,
+    ParallelLMHead,
+)
 from aphrodite.modeling.megatron.parallel_state import (
-    get_tensor_model_parallel_world_size)
+    get_tensor_model_parallel_world_size, )
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
-from aphrodite.modeling.hf_downloader import (default_weight_loader,
-                                              hf_model_weights_iterator)
+from aphrodite.modeling.hf_downloader import (
+    default_weight_loader,
+    hf_model_weights_iterator,
+)
 from aphrodite.common.sequence import SamplerOutput
 
 KVCache = Tuple[torch.Tensor, torch.Tensor]
@@ -82,8 +89,9 @@ class GPTNeoXAttention(nn.Module):
         rope_theta = getattr(config, "rope_theta", 10000)
         max_position_embeddings = getattr(config, "max_position_embeddings",
                                           8192)
-        is_neox_style = True if linear_method is None or linear_method.quant_config.rope_style(
-        ) is None else linear_method.quant_config.rope_style()
+        is_neox_style = (True if linear_method is None
+                         or linear_method.quant_config.rope_style() is None
+                         else linear_method.quant_config.rope_style())
         self.rotary_emb = get_rope(
             self.head_size,
             rotary_dim=rotary_dim,
@@ -262,11 +270,13 @@ class GPTNeoXForCausalLM(nn.Module):
                                    sampling_metadata)
         return next_tokens
 
-    def load_weights(self,
-                     model_name_or_path: str,
-                     cache_dir: Optional[str] = None,
-                     load_format: str = "auto",
-                     revision: Optional[str] = None):
+    def load_weights(
+        self,
+        model_name_or_path: str,
+        cache_dir: Optional[str] = None,
+        load_format: str = "auto",
+        revision: Optional[str] = None,
+    ):
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in hf_model_weights_iterator(
                 model_name_or_path, cache_dir, load_format, revision,
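
The reformatted is_neox_style ternary above (also present in baichuan.py) is dense; it reduces to "use the RoPE style pinned by the quantization config when there is one, otherwise default to NeoX-style". An equivalent, more explicit spelling, offered as a sketch rather than the project's code (it also calls rope_style() only once):

    def resolve_is_neox_style(linear_method) -> bool:
        if linear_method is None:
            return True
        rope_style = linear_method.quant_config.rope_style()
        return True if rope_style is None else rope_style

    print(resolve_is_neox_style(None))   # True -> NeoX-style rotary embedding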

+ 60 - 41
aphrodite/modeling/models/internlm2.py

@@ -9,20 +9,26 @@ from aphrodite.modeling.metadata import InputMetadata
 from aphrodite.modeling.layers.activation import SiluAndMul
 from aphrodite.modeling.layers.attention import PagedAttention
 from aphrodite.modeling.layers.layernorm import RMSNorm
-from aphrodite.modeling.layers.linear import (LinearMethodBase,
-                                              ColumnParallelLinear,
-                                              MergedColumnParallelLinear,
-                                              QKVParallelLinear,
-                                              RowParallelLinear)
+from aphrodite.modeling.layers.linear import (
+    LinearMethodBase,
+    ColumnParallelLinear,
+    MergedColumnParallelLinear,
+    QKVParallelLinear,
+    RowParallelLinear,
+)
 from aphrodite.modeling.layers.rotary_embedding import get_rope
 from aphrodite.modeling.layers.sampler import Sampler
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
-    VocabParallelEmbedding, ParallelLMHead)
+    VocabParallelEmbedding,
+    ParallelLMHead,
+)
 from aphrodite.modeling.megatron.parallel_state import (
-    get_tensor_model_parallel_world_size)
+    get_tensor_model_parallel_world_size, )
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
-from aphrodite.modeling.hf_downloader import (default_weight_loader,
-                                              hf_model_weights_iterator)
+from aphrodite.modeling.hf_downloader import (
+    default_weight_loader,
+    hf_model_weights_iterator,
+)
 from aphrodite.common.sequence import SamplerOutput
 
 KVCache = Tuple[torch.Tensor, torch.Tensor]
@@ -38,27 +44,35 @@ class InternLM2MLP(nn.Module):
         linear_method: Optional[LinearMethodBase] = None,
     ) -> None:
         super().__init__()
-        if linear_method is not None and not linear_method.quant_config.merge_weight(
-        ):
+        if (linear_method is not None
+                and not linear_method.quant_config.merge_weight()):
             self.merge_weight = False
-            self.w1 = ColumnParallelLinear(hidden_size,
-                                           intermediate_size,
-                                           bias=False,
-                                           linear_method=linear_method)
-            self.w3 = ColumnParallelLinear(hidden_size,
-                                           intermediate_size,
-                                           bias=False,
-                                           linear_method=linear_method)
+            self.w1 = ColumnParallelLinear(
+                hidden_size,
+                intermediate_size,
+                bias=False,
+                linear_method=linear_method,
+            )
+            self.w3 = ColumnParallelLinear(
+                hidden_size,
+                intermediate_size,
+                bias=False,
+                linear_method=linear_method,
+            )
         else:
             self.merge_weight = True
             self.gate_up_proj = MergedColumnParallelLinear(
-                hidden_size, [intermediate_size] * 2,
+                hidden_size,
+                [intermediate_size] * 2,
                 bias=False,
-                linear_method=linear_method)
-        self.w2 = RowParallelLinear(intermediate_size,
-                                    hidden_size,
-                                    bias=False,
-                                    linear_method=linear_method)
+                linear_method=linear_method,
+            )
+        self.w2 = RowParallelLinear(
+            intermediate_size,
+            hidden_size,
+            bias=False,
+            linear_method=linear_method,
+        )
         if hidden_act != "silu":
             raise ValueError(f"Unsupported activation: {hidden_act}. "
                              "Only silu is supported for now.")
@@ -133,10 +147,12 @@ class InternLM2Attention(nn.Module):
             base=rope_theta,
             rope_scaling=rope_scaling,
         )
-        self.attn = PagedAttention(self.num_heads,
-                                   self.head_dim,
-                                   self.scaling,
-                                   num_kv_heads=self.num_kv_heads)
+        self.attn = PagedAttention(
+            self.num_heads,
+            self.head_dim,
+            self.scaling,
+            num_kv_heads=self.num_kv_heads,
+        )
 
     def forward(
         self,
@@ -296,25 +312,27 @@ class InternLM2ForCausalLM(nn.Module):
                                    sampling_metadata)
         return next_tokens
 
-    def load_weights(self,
-                     model_name_or_path: str,
-                     cache_dir: Optional[str] = None,
-                     load_format: str = "auto",
-                     revision: Optional[str] = None):
+    def load_weights(
+        self,
+        model_name_or_path: str,
+        cache_dir: Optional[str] = None,
+        load_format: str = "auto",
+        revision: Optional[str] = None,
+    ):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("gate_up_proj", "w1", 0),
             ("gate_up_proj", "w3", 1),
         ]
-        if self.linear_method is not None and not self.linear_method.quant_config.merge_weight(
-        ):
+        if (self.linear_method is not None
+                and not self.linear_method.quant_config.merge_weight()):
             stacked_params_mapping = []
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in hf_model_weights_iterator(
                 model_name_or_path, cache_dir, load_format, revision):
             if "rotary_emb.inv_freq" in name:
                 continue
-            for (param_name, weight_name, shard_id) in stacked_params_mapping:
+            for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
                     continue
                 name = name.replace(weight_name, param_name)
@@ -332,7 +350,8 @@ class InternLM2ForCausalLM(nn.Module):
                 param = params_dict[name]
                 if "wqkv" in name:
                     config = self.config
-                    kv_groups = config.num_attention_heads // config.num_key_value_heads
+                    kv_groups = (config.num_attention_heads //
+                                 config.num_key_value_heads)
                     head_dim = config.hidden_size // config.num_attention_heads
                     loaded_weight = loaded_weight.view(-1, 2 + kv_groups,
                                                        head_dim,
@@ -343,9 +362,9 @@ class InternLM2ForCausalLM(nn.Module):
                     wk = wk.reshape(-1, wk.shape[-1])
                     wv = wv.reshape(-1, wv.shape[-1])
                     weight_loader = param.weight_loader
-                    weight_loader(param, wq, 'q')
-                    weight_loader(param, wk, 'k')
-                    weight_loader(param, wv, 'v')
+                    weight_loader(param, wq, "q")
+                    weight_loader(param, wk, "k")
+                    weight_loader(param, wv, "v")
                 else:
                     weight_loader = getattr(param, "weight_loader",
                                             default_weight_loader)
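
A note on the wqkv handling above: the checkpoint packs, per KV head, `kv_groups` query heads plus one key head and one value head into a single tensor, which load_weights reshapes and splits back into q/k/v shards. A minimal standalone sketch of that reshaping, assuming made-up sizes (32 query heads, 8 KV heads, head_dim 128):

import torch

# Hypothetical sizes for illustration only.
num_attention_heads, num_key_value_heads, head_dim = 32, 8, 128
hidden_size = num_attention_heads * head_dim
kv_groups = num_attention_heads // num_key_value_heads  # query heads per KV head

# Fused wqkv weight: per KV head, kv_groups query rows + 1 key row + 1 value row.
wqkv = torch.randn(num_key_value_heads * (kv_groups + 2) * head_dim, hidden_size)

grouped = wqkv.view(-1, 2 + kv_groups, head_dim, hidden_size)
wq, wk, wv = torch.split(grouped, [kv_groups, 1, 1], dim=1)
wq = wq.reshape(-1, wq.shape[-1])  # (num_attention_heads * head_dim, hidden_size)
wk = wk.reshape(-1, wk.shape[-1])  # (num_key_value_heads * head_dim, hidden_size)
wv = wv.reshape(-1, wv.shape[-1])  # (num_key_value_heads * head_dim, hidden_size)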

+ 77 - 52
aphrodite/modeling/models/llama.py

@@ -21,6 +21,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only LLaMA model compatible with HuggingFace weights."""
+
 from typing import Any, Dict, List, Optional, Tuple
 
 import torch
@@ -31,20 +32,27 @@ from aphrodite.modeling.metadata import InputMetadata
 from aphrodite.modeling.layers.activation import SiluAndMul
 from aphrodite.modeling.layers.attention import PagedAttention
 from aphrodite.modeling.layers.layernorm import RMSNorm
-from aphrodite.modeling.layers.linear import (LinearMethodBase,
-                                              MergedColumnParallelLinear,
-                                              QKVParallelLinear,
-                                              RowParallelLinear,
-                                              ColumnParallelLinear)
+from aphrodite.modeling.layers.linear import (
+    LinearMethodBase,
+    MergedColumnParallelLinear,
+    QKVParallelLinear,
+    RowParallelLinear,
+    ColumnParallelLinear,
+)
 from aphrodite.modeling.layers.rotary_embedding import get_rope
 from aphrodite.modeling.layers.sampler import Sampler
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
-    VocabParallelEmbedding, ParallelLMHead, DEFAULT_VOCAB_PADDING_SIZE)
+    VocabParallelEmbedding,
+    ParallelLMHead,
+    DEFAULT_VOCAB_PADDING_SIZE,
+)
 from aphrodite.modeling.megatron.parallel_state import (
-    get_tensor_model_parallel_world_size)
+    get_tensor_model_parallel_world_size, )
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
-from aphrodite.modeling.hf_downloader import (default_weight_loader,
-                                              hf_model_weights_iterator)
+from aphrodite.modeling.hf_downloader import (
+    default_weight_loader,
+    hf_model_weights_iterator,
+)
 from aphrodite.common.sequence import SamplerOutput
 from aphrodite.common.config import LoRAConfig
 
@@ -61,27 +69,35 @@ class LlamaMLP(nn.Module):
         linear_method: Optional[LinearMethodBase] = None,
     ) -> None:
         super().__init__()
-        if linear_method is not None and not linear_method.quant_config.merge_weight(
-        ):
+        if (linear_method is not None
+                and not linear_method.quant_config.merge_weight()):
             self.merge_weight = False
-            self.gate_proj = ColumnParallelLinear(hidden_size,
-                                                  intermediate_size,
-                                                  bias=False,
-                                                  linear_method=linear_method)
-            self.up_proj = ColumnParallelLinear(hidden_size,
-                                                intermediate_size,
-                                                bias=False,
-                                                linear_method=linear_method)
+            self.gate_proj = ColumnParallelLinear(
+                hidden_size,
+                intermediate_size,
+                bias=False,
+                linear_method=linear_method,
+            )
+            self.up_proj = ColumnParallelLinear(
+                hidden_size,
+                intermediate_size,
+                bias=False,
+                linear_method=linear_method,
+            )
         else:
             self.merge_weight = True
             self.gate_up_proj = MergedColumnParallelLinear(
-                hidden_size, [intermediate_size] * 2,
+                hidden_size,
+                [intermediate_size] * 2,
                 bias=False,
-                linear_method=linear_method)
-        self.down_proj = RowParallelLinear(intermediate_size,
-                                           hidden_size,
-                                           bias=False,
-                                           linear_method=linear_method)
+                linear_method=linear_method,
+            )
+        self.down_proj = RowParallelLinear(
+            intermediate_size,
+            hidden_size,
+            bias=False,
+            linear_method=linear_method,
+        )
         if hidden_act != "silu":
             raise ValueError(f"Unsupported activation: {hidden_act}. "
                              "Only silu is supported for now.")
@@ -136,21 +152,25 @@ class LlamaAttention(nn.Module):
         self.rope_theta = rope_theta
         self.max_position_embeddings = max_position_embeddings
 
-        if linear_method is not None and not linear_method.quant_config.merge_weight(
-        ):
+        if (linear_method is not None
+                and not linear_method.quant_config.merge_weight()):
             self.merge_weight = False
             self.q_proj = ColumnParallelLinear(hidden_size,
                                                self.q_size,
                                                bias=bias,
                                                linear_method=linear_method)
-            self.k_proj = ColumnParallelLinear(hidden_size,
-                                               self.kv_size,
-                                               bias=bias,
-                                               linear_method=linear_method)
-            self.v_proj = ColumnParallelLinear(hidden_size,
-                                               self.kv_size,
-                                               bias=bias,
-                                               linear_method=linear_method)
+            self.k_proj = ColumnParallelLinear(
+                hidden_size,
+                self.kv_size,
+                bias=bias,
+                linear_method=linear_method,
+            )
+            self.v_proj = ColumnParallelLinear(
+                hidden_size,
+                self.kv_size,
+                bias=bias,
+                linear_method=linear_method,
+            )
         else:
             self.merge_weight = True
             self.qkv_proj = QKVParallelLinear(
@@ -168,8 +188,9 @@ class LlamaAttention(nn.Module):
             linear_method=linear_method,
         )
 
-        is_neox_style = True if linear_method is None or linear_method.quant_config.rope_style(
-        ) is None else linear_method.quant_config.rope_style()
+        is_neox_style = (True if linear_method is None
+                         or linear_method.quant_config.rope_style() is None
+                         else linear_method.quant_config.rope_style())
         self.rotary_emb = get_rope(
             self.head_dim,
             rotary_dim=self.head_dim,
@@ -178,11 +199,13 @@ class LlamaAttention(nn.Module):
             rope_scaling=rope_scaling,
             is_neox_style=is_neox_style,
         )
-        self.attn = PagedAttention(self.num_heads,
-                                   self.head_dim,
-                                   self.scaling,
-                                   num_kv_heads=self.num_kv_heads,
-                                   sliding_window=sliding_window)
+        self.attn = PagedAttention(
+            self.num_heads,
+            self.head_dim,
+            self.scaling,
+            num_kv_heads=self.num_kv_heads,
+            sliding_window=sliding_window,
+        )
 
     def forward(
         self,
@@ -287,8 +310,8 @@ class LlamaModel(nn.Module):
         super().__init__()
         self.config = config
         self.padding_idx = config.pad_token_id
-        lora_vocab = (lora_config.lora_extra_vocab_size *
-                      (lora_config.max_loras or 1)) if lora_config else 0
+        lora_vocab = ((lora_config.lora_extra_vocab_size *
+                       (lora_config.max_loras or 1)) if lora_config else 0)
         self.vocab_size = config.vocab_size + lora_vocab
         self.org_vocab_size = config.vocab_size
         self.embed_tokens = VocabParallelEmbedding(
@@ -400,11 +423,13 @@ class LlamaForCausalLM(nn.Module):
                                    sampling_metadata)
         return next_tokens
 
-    def load_weights(self,
-                     model_name_or_path: str,
-                     cache_dir: Optional[str] = None,
-                     load_format: str = "auto",
-                     revision: Optional[str] = None):
+    def load_weights(
+        self,
+        model_name_or_path: str,
+        cache_dir: Optional[str] = None,
+        load_format: str = "auto",
+        revision: Optional[str] = None,
+    ):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("qkv_proj", "q_proj", "q"),
@@ -413,8 +438,8 @@ class LlamaForCausalLM(nn.Module):
             ("gate_up_proj", "gate_proj", 0),
             ("gate_up_proj", "up_proj", 1),
         ]
-        if self.linear_method is not None and not self.linear_method.quant_config.merge_weight(
-        ):
+        if (self.linear_method is not None
+                and not self.linear_method.quant_config.merge_weight()):
             stacked_params_mapping = []
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in hf_model_weights_iterator(
@@ -427,7 +452,7 @@ class LlamaForCausalLM(nn.Module):
                 # Models trained using ColossalAI may include these tensors in
                 # the checkpoint. Skip them.
                 continue
-            for (param_name, weight_name, shard_id) in stacked_params_mapping:
+            for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
                     continue
                 name = name.replace(weight_name, param_name)
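
The load_weights loop above folds separate q/k/v and gate/up checkpoint tensors into the fused parameters by renaming keys. A small sketch of just the name-remapping step (toy key, not the exact loader):

# (param_name, shard_name, shard_id), as in the mapping above.
stacked_params_mapping = [
    ("qkv_proj", "q_proj", "q"),
    ("qkv_proj", "k_proj", "k"),
    ("qkv_proj", "v_proj", "v"),
    ("gate_up_proj", "gate_proj", 0),
    ("gate_up_proj", "up_proj", 1),
]

def resolve(name):
    """Return the fused parameter name and shard id for a checkpoint key."""
    for param_name, weight_name, shard_id in stacked_params_mapping:
        if weight_name in name:
            return name.replace(weight_name, param_name), shard_id
    return name, None  # falls through to the default weight loader

print(resolve("model.layers.0.self_attn.k_proj.weight"))
# ('model.layers.0.self_attn.qkv_proj.weight', 'k')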

+ 129 - 84
aphrodite/modeling/models/mixtral.py

@@ -22,6 +22,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only Mixtral model."""
+
 from typing import List, Optional, Tuple
 
 import torch
@@ -33,23 +34,32 @@ from aphrodite.modeling.metadata import InputMetadata
 from aphrodite.modeling.layers.attention import PagedAttention
 from aphrodite.modeling.layers.triton_kernel.fused_moe import fused_moe
 from aphrodite.modeling.layers.layernorm import RMSNorm
-from aphrodite.modeling.layers.linear import (LinearMethodBase,
-                                              QKVParallelLinear,
-                                              ReplicatedLinear,
-                                              RowParallelLinear,
-                                              ColumnParallelLinear)
+from aphrodite.modeling.layers.linear import (
+    LinearMethodBase,
+    QKVParallelLinear,
+    ReplicatedLinear,
+    RowParallelLinear,
+    ColumnParallelLinear,
+)
 from aphrodite.modeling.layers.rotary_embedding import get_rope
 from aphrodite.modeling.layers.sampler import Sampler
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
-    VocabParallelEmbedding, ParallelLMHead, DEFAULT_VOCAB_PADDING_SIZE)
+    VocabParallelEmbedding,
+    ParallelLMHead,
+    DEFAULT_VOCAB_PADDING_SIZE,
+)
 from aphrodite.modeling.megatron.communication_op import (
-    tensor_model_parallel_all_reduce)
+    tensor_model_parallel_all_reduce, )
 from aphrodite.modeling.megatron.parallel_state import (
-    get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+)
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
 from aphrodite.modeling.utils import set_weight_attrs
-from aphrodite.modeling.hf_downloader import (default_weight_loader,
-                                              hf_model_weights_iterator)
+from aphrodite.modeling.hf_downloader import (
+    default_weight_loader,
+    hf_model_weights_iterator,
+)
 from aphrodite.common.sequence import SamplerOutput
 
 KVCache = Tuple[torch.Tensor, torch.Tensor]
@@ -84,34 +94,51 @@ class MixtralMoE(nn.Module):
             params_dtype = torch.get_default_dtype()
         self.params_dtype = params_dtype
 
-        self.gate = ReplicatedLinear(self.hidden_size,
-                                     self.num_total_experts,
-                                     bias=False,
-                                     params_dtype=self.params_dtype,
-                                     linear_method=None)
+        self.gate = ReplicatedLinear(
+            self.hidden_size,
+            self.num_total_experts,
+            bias=False,
+            params_dtype=self.params_dtype,
+            linear_method=None,
+        )
 
         self.ws = nn.Parameter(
-            torch.empty(self.num_total_experts,
-                        2 * self.intermediate_size,
-                        self.hidden_size,
-                        device="cuda",
-                        dtype=self.params_dtype))
+            torch.empty(
+                self.num_total_experts,
+                2 * self.intermediate_size,
+                self.hidden_size,
+                device="cuda",
+                dtype=self.params_dtype,
+            ))
         self.w2s = nn.Parameter(
-            torch.empty(self.num_total_experts,
-                        self.hidden_size,
-                        self.intermediate_size,
-                        device="cuda",
-                        dtype=self.params_dtype))
-
-        set_weight_attrs(self.ws, {
-            "weight_loader": self.weight_loader,
-        })
-        set_weight_attrs(self.w2s, {
-            "weight_loader": self.weight_loader,
-        })
-
-    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
-                      weight_name: str, expert_id: int):
+            torch.empty(
+                self.num_total_experts,
+                self.hidden_size,
+                self.intermediate_size,
+                device="cuda",
+                dtype=self.params_dtype,
+            ))
+
+        set_weight_attrs(
+            self.ws,
+            {
+                "weight_loader": self.weight_loader,
+            },
+        )
+        set_weight_attrs(
+            self.w2s,
+            {
+                "weight_loader": self.weight_loader,
+            },
+        )
+
+    def weight_loader(
+        self,
+        param: nn.Parameter,
+        loaded_weight: torch.Tensor,
+        weight_name: str,
+        expert_id: int,
+    ):
         tp_rank = get_tensor_model_parallel_rank()
         param_data = param.data
         shard_size = self.intermediate_size
@@ -119,8 +146,8 @@ class MixtralMoE(nn.Module):
         if weight_name.endswith("w1.weight"):
             param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :]
         if weight_name.endswith("w3.weight"):
-            param_data[expert_id,
-                       shard_size:2 * shard_size, :] = loaded_weight[shard, :]
+            param_data[expert_id, shard_size:2 *
+                       shard_size, :] = (loaded_weight[shard, :])
         if weight_name.endswith("w2.weight"):
             param_data[expert_id, :, :] = loaded_weight[:, shard]
 
@@ -129,13 +156,15 @@ class MixtralMoE(nn.Module):
         hidden_states = hidden_states.view(-1, self.hidden_size)
         # router_logits: (batch * sequence_length, n_experts)
         router_logits, _ = self.gate(hidden_states)
-        final_hidden_states = fused_moe(hidden_states,
-                                        self.ws,
-                                        self.w2s,
-                                        router_logits,
-                                        self.top_k,
-                                        renormalize=True,
-                                        inplace=True)
+        final_hidden_states = fused_moe(
+            hidden_states,
+            self.ws,
+            self.w2s,
+            router_logits,
+            self.top_k,
+            renormalize=True,
+            inplace=True,
+        )
 
         if self.tp_size > 1:
             final_hidden_states = tensor_model_parallel_all_reduce(
@@ -147,14 +176,16 @@ class MixtralMoE(nn.Module):
 
 class MixtralAttention(nn.Module):
 
-    def __init__(self,
-                 hidden_size: int,
-                 num_heads: int,
-                 num_kv_heads: int,
-                 max_position: int = 4096 * 32,
-                 rope_theta: float = 10000,
-                 linear_method: Optional[LinearMethodBase] = None,
-                 sliding_window: Optional[int] = None) -> None:
+    def __init__(
+        self,
+        hidden_size: int,
+        num_heads: int,
+        num_kv_heads: int,
+        max_position: int = 4096 * 32,
+        rope_theta: float = 10000,
+        linear_method: Optional[LinearMethodBase] = None,
+        sliding_window: Optional[int] = None,
+    ) -> None:
         super().__init__()
         self.hidden_size = hidden_size
         tp_size = get_tensor_model_parallel_world_size()
@@ -178,21 +209,27 @@ class MixtralAttention(nn.Module):
         self.rope_theta = rope_theta
         self.sliding_window = sliding_window
 
-        if linear_method is not None and not linear_method.quant_config.merge_weight(
-        ):
+        if (linear_method is not None
+                and not linear_method.quant_config.merge_weight()):
             self.merge_weight = False
-            self.q_proj = ColumnParallelLinear(hidden_size,
-                                               self.q_size,
-                                               bias=False,
-                                               linear_method=linear_method)
-            self.k_proj = ColumnParallelLinear(hidden_size,
-                                               self.kv_size,
-                                               bias=False,
-                                               linear_method=linear_method)
-            self.v_proj = ColumnParallelLinear(hidden_size,
-                                               self.kv_size,
-                                               bias=False,
-                                               linear_method=linear_method)
+            self.q_proj = ColumnParallelLinear(
+                hidden_size,
+                self.q_size,
+                bias=False,
+                linear_method=linear_method,
+            )
+            self.k_proj = ColumnParallelLinear(
+                hidden_size,
+                self.kv_size,
+                bias=False,
+                linear_method=linear_method,
+            )
+            self.v_proj = ColumnParallelLinear(
+                hidden_size,
+                self.kv_size,
+                bias=False,
+                linear_method=linear_method,
+            )
         else:
             self.merge_weight = True
             self.qkv_proj = QKVParallelLinear(
@@ -209,8 +246,9 @@ class MixtralAttention(nn.Module):
             bias=False,
             linear_method=linear_method,
         )
-        is_neox_style = True if linear_method is None or linear_method.quant_config.rope_style(
-        ) is None else linear_method.quant_config.rope_style()
+        is_neox_style = (True if linear_method is None
+                         or linear_method.quant_config.rope_style() is None
+                         else linear_method.quant_config.rope_style())
         self.rotary_emb = get_rope(
             self.head_dim,
             rotary_dim=self.head_dim,
@@ -266,12 +304,14 @@ class MixtralDecoderLayer(nn.Module):
             num_kv_heads=config.num_key_value_heads,
             rope_theta=rope_theta,
             sliding_window=config.sliding_window,
-            linear_method=linear_method)
+            linear_method=linear_method,
+        )
         self.block_sparse_moe = MixtralMoE(
             num_experts=config.num_local_experts,
             top_k=config.num_experts_per_tok,
             hidden_size=config.hidden_size,
-            intermediate_size=config.intermediate_size)
+            intermediate_size=config.intermediate_size,
+        )
         self.input_layernorm = RMSNorm(config.hidden_size,
                                        eps=config.rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(config.hidden_size,
@@ -316,8 +356,8 @@ class MixtralModel(nn.Module):
     ) -> None:
         super().__init__()
         self.padding_idx = config.pad_token_id
-        lora_vocab = (lora_config.lora_extra_vocab_size *
-                      (lora_config.max_loras or 1)) if lora_config else 0
+        lora_vocab = ((lora_config.lora_extra_vocab_size *
+                       (lora_config.max_loras or 1)) if lora_config else 0)
         self.vocab_size = config.vocab_size + lora_vocab
         self.org_vocab_size = config.vocab_size
 
@@ -420,26 +460,30 @@ class MixtralForCausalLM(nn.Module):
                                    sampling_metadata)
         return next_tokens
 
-    def load_weights(self,
-                     model_name_or_path: str,
-                     cache_dir: Optional[str] = None,
-                     load_format: str = "auto",
-                     revision: Optional[str] = None):
+    def load_weights(
+        self,
+        model_name_or_path: str,
+        cache_dir: Optional[str] = None,
+        load_format: str = "auto",
+        revision: Optional[str] = None,
+    ):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("qkv_proj", "q_proj", "q"),
             ("qkv_proj", "k_proj", "k"),
             ("qkv_proj", "v_proj", "v"),
         ]
-        if self.linear_method is not None and not self.linear_method.quant_config.merge_weight(
-        ):
+        if (self.linear_method is not None
+                and not self.linear_method.quant_config.merge_weight()):
             stacked_params_mapping = []
 
         expert_params_mapping = [
             # (param_name, weight_name, expert_id)
-            ("ws" if weight_name in ["w1", "w3"] else "w2s",
-             f"experts.{expert_id}.{weight_name}.weight", expert_id)
-            for expert_id in range(self.config.num_local_experts)
+            (
+                "ws" if weight_name in ["w1", "w3"] else "w2s",
+                f"experts.{expert_id}.{weight_name}.weight",
+                expert_id,
+            ) for expert_id in range(self.config.num_local_experts)
             for weight_name in ["w1", "w2", "w3"]
         ]
 
@@ -450,11 +494,12 @@ class MixtralForCausalLM(nn.Module):
                 load_format,
                 revision,
                 self.config,
-                fall_back_to_pt=False):
+                fall_back_to_pt=False,
+        ):
             if "rotary_emb.inv_freq" in name:
                 continue
 
-            for (param_name, weight_name, shard_id) in stacked_params_mapping:
+            for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
                     continue
                 name = name.replace(weight_name, param_name)
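
For context on the MixtralMoE weight_loader reformatted above: each expert's w1/w3 go into the first and second halves of the fused `ws` tensor, and w2 into `w2s`. A toy, single-rank sketch of that placement (made-up sizes, tensor-parallel slicing omitted):

import torch

num_experts, intermediate_size, hidden_size = 2, 4, 8
ws = torch.empty(num_experts, 2 * intermediate_size, hidden_size)
w2s = torch.empty(num_experts, hidden_size, intermediate_size)

def load_expert(weight_name, expert_id, loaded_weight):
    shard_size = intermediate_size
    if weight_name.endswith("w1.weight"):
        ws[expert_id, 0:shard_size, :] = loaded_weight
    elif weight_name.endswith("w3.weight"):
        ws[expert_id, shard_size:2 * shard_size, :] = loaded_weight
    elif weight_name.endswith("w2.weight"):
        w2s[expert_id, :, :] = loaded_weight

load_expert("experts.0.w1.weight", 0, torch.randn(intermediate_size, hidden_size))
load_expert("experts.0.w3.weight", 0, torch.randn(intermediate_size, hidden_size))
load_expert("experts.0.w2.weight", 0, torch.randn(hidden_size, intermediate_size))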

+ 94 - 64
aphrodite/modeling/models/mixtral_quant.py

@@ -22,6 +22,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only Mixtral model."""
+
 from typing import List, Optional, Tuple
 
 import numpy as np
@@ -35,22 +36,30 @@ from transformers import MixtralConfig
 from aphrodite.modeling.metadata import InputMetadata
 from aphrodite.modeling.layers.attention import PagedAttention
 from aphrodite.modeling.layers.layernorm import RMSNorm
-from aphrodite.modeling.layers.linear import (LinearMethodBase,
-                                              ReplicatedLinear,
-                                              QKVParallelLinear,
-                                              RowParallelLinear,
-                                              ColumnParallelLinear)
+from aphrodite.modeling.layers.linear import (
+    LinearMethodBase,
+    ReplicatedLinear,
+    QKVParallelLinear,
+    RowParallelLinear,
+    ColumnParallelLinear,
+)
 from aphrodite.modeling.layers.rotary_embedding import get_rope
 from aphrodite.modeling.layers.sampler import Sampler
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
-    VocabParallelEmbedding, ParallelLMHead)
+    VocabParallelEmbedding,
+    ParallelLMHead,
+)
 from aphrodite.modeling.megatron.communication_op import (
-    tensor_model_parallel_all_reduce)
+    tensor_model_parallel_all_reduce, )
 from aphrodite.modeling.megatron.parallel_state import (
-    get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+)
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
-from aphrodite.modeling.hf_downloader import (default_weight_loader,
-                                              hf_model_weights_iterator)
+from aphrodite.modeling.hf_downloader import (
+    default_weight_loader,
+    hf_model_weights_iterator,
+)
 from aphrodite.common.sequence import SamplerOutput
 
 KVCache = Tuple[torch.Tensor, torch.Tensor]
@@ -70,18 +79,24 @@ class MixtralMLP(nn.Module):
         self.ffn_dim = intermediate_size
         self.hidden_dim = hidden_size
 
-        self.w1 = ReplicatedLinear(self.hidden_dim,
-                                   self.ffn_dim,
-                                   bias=False,
-                                   linear_method=linear_method)
-        self.w2 = ReplicatedLinear(self.ffn_dim,
-                                   self.hidden_dim,
-                                   bias=False,
-                                   linear_method=linear_method)
-        self.w3 = ReplicatedLinear(self.hidden_dim,
-                                   self.ffn_dim,
-                                   bias=False,
-                                   linear_method=linear_method)
+        self.w1 = ReplicatedLinear(
+            self.hidden_dim,
+            self.ffn_dim,
+            bias=False,
+            linear_method=linear_method,
+        )
+        self.w2 = ReplicatedLinear(
+            self.ffn_dim,
+            self.hidden_dim,
+            bias=False,
+            linear_method=linear_method,
+        )
+        self.w3 = ReplicatedLinear(
+            self.hidden_dim,
+            self.ffn_dim,
+            bias=False,
+            linear_method=linear_method,
+        )
 
         # TODO: Use Aphrodite's SiluAndMul
         self.act_fn = nn.SiLU()
@@ -120,17 +135,20 @@ class MixtralMoE(nn.Module):
                 f"Rank {self.rank} has no experts assigned to it.")
 
         self.experts = nn.ModuleList([
-            MixtralMLP(self.num_total_experts,
-                       config.hidden_size,
-                       config.intermediate_size,
-                       linear_method=linear_method)
-            if idx in self.expert_indicies else None
+            MixtralMLP(
+                self.num_total_experts,
+                config.hidden_size,
+                config.intermediate_size,
+                linear_method=linear_method,
+            ) if idx in self.expert_indicies else None
             for idx in range(self.num_total_experts)
         ])
-        self.gate = ReplicatedLinear(config.hidden_size,
-                                     self.num_total_experts,
-                                     bias=False,
-                                     linear_method=None)
+        self.gate = ReplicatedLinear(
+            config.hidden_size,
+            self.num_total_experts,
+            bias=False,
+            linear_method=None,
+        )
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         batch_size, sequence_length, hidden_dim = hidden_states.shape
@@ -147,7 +165,7 @@ class MixtralMoE(nn.Module):
         final_hidden_states = None
         for expert_idx in self.expert_indicies:
             expert_layer = self.experts[expert_idx]
-            expert_mask = (selected_experts == expert_idx)
+            expert_mask = selected_experts == expert_idx
             expert_weights = (routing_weights * expert_mask).sum(dim=-1,
                                                                  keepdim=True)
 
@@ -164,14 +182,16 @@ class MixtralMoE(nn.Module):
 
 class MixtralAttention(nn.Module):
 
-    def __init__(self,
-                 hidden_size: int,
-                 num_heads: int,
-                 num_kv_heads: int,
-                 max_position: int = 4096 * 32,
-                 rope_theta: float = 10000,
-                 linear_method: Optional[LinearMethodBase] = None,
-                 sliding_window: Optional[int] = None) -> None:
+    def __init__(
+        self,
+        hidden_size: int,
+        num_heads: int,
+        num_kv_heads: int,
+        max_position: int = 4096 * 32,
+        rope_theta: float = 10000,
+        linear_method: Optional[LinearMethodBase] = None,
+        sliding_window: Optional[int] = None,
+    ) -> None:
         super().__init__()
         self.hidden_size = hidden_size
         tp_size = get_tensor_model_parallel_world_size()
@@ -195,21 +215,27 @@ class MixtralAttention(nn.Module):
         self.rope_theta = rope_theta
         self.sliding_window = sliding_window
 
-        if linear_method is not None and not linear_method.quant_config.merge_weight(
-        ):
+        if (linear_method is not None
+                and not linear_method.quant_config.merge_weight()):
             self.merge_weight = False
-            self.q_proj = ColumnParallelLinear(hidden_size,
-                                               self.q_size,
-                                               bias=False,
-                                               linear_method=linear_method)
-            self.k_proj = ColumnParallelLinear(hidden_size,
-                                               self.kv_size,
-                                               bias=False,
-                                               linear_method=linear_method)
-            self.v_proj = ColumnParallelLinear(hidden_size,
-                                               self.kv_size,
-                                               bias=False,
-                                               linear_method=linear_method)
+            self.q_proj = ColumnParallelLinear(
+                hidden_size,
+                self.q_size,
+                bias=False,
+                linear_method=linear_method,
+            )
+            self.k_proj = ColumnParallelLinear(
+                hidden_size,
+                self.kv_size,
+                bias=False,
+                linear_method=linear_method,
+            )
+            self.v_proj = ColumnParallelLinear(
+                hidden_size,
+                self.kv_size,
+                bias=False,
+                linear_method=linear_method,
+            )
         else:
             self.merge_weight = True
             self.qkv_proj = QKVParallelLinear(
@@ -281,7 +307,8 @@ class MixtralDecoderLayer(nn.Module):
             num_kv_heads=config.num_key_value_heads,
             rope_theta=rope_theta,
             sliding_window=config.sliding_window,
-            linear_method=linear_method)
+            linear_method=linear_method,
+        )
         self.block_sparse_moe = MixtralMoE(config=config,
                                            linear_method=linear_method)
         self.input_layernorm = RMSNorm(config.hidden_size,
@@ -396,19 +423,21 @@ class MixtralForCausalLM(nn.Module):
                                    sampling_metadata)
         return next_tokens
 
-    def load_weights(self,
-                     model_name_or_path: str,
-                     cache_dir: Optional[str] = None,
-                     load_format: str = "auto",
-                     revision: Optional[str] = None):
+    def load_weights(
+        self,
+        model_name_or_path: str,
+        cache_dir: Optional[str] = None,
+        load_format: str = "auto",
+        revision: Optional[str] = None,
+    ):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("qkv_proj", "q_proj", "q"),
             ("qkv_proj", "k_proj", "k"),
             ("qkv_proj", "v_proj", "v"),
         ]
-        if self.linear_method is not None and not self.linear_method.quant_config.merge_weight(
-        ):
+        if (self.linear_method is not None
+                and not self.linear_method.quant_config.merge_weight()):
             stacked_params_mapping = []
 
         params_dict = dict(self.named_parameters())
@@ -418,10 +447,11 @@ class MixtralForCausalLM(nn.Module):
                 load_format,
                 revision,
                 self.config,
-                fall_back_to_pt=False):
+                fall_back_to_pt=False,
+        ):
             if "rotary_emb.inv_freq" in name:
                 continue
-            for (param_name, weight_name, shard_id) in stacked_params_mapping:
+            for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
                     continue
                 name = name.replace(weight_name, param_name)
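
The non-fused MoE forward touched earlier in this file routes each token through its top-k experts with a boolean mask per expert. A toy CPU sketch of that routing, assuming random logits and identity experts (shapes are illustrative):

import torch
import torch.nn.functional as F

num_experts, top_k, hidden = 4, 2, 8
hidden_states = torch.randn(3, hidden)            # (tokens, hidden)
router_logits = torch.randn(3, num_experts)

routing_weights = F.softmax(router_logits, dim=1)
routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)

final_hidden_states = torch.zeros_like(hidden_states)
for expert_idx in range(num_experts):             # the real code walks expert_indicies
    expert_mask = selected_experts == expert_idx
    expert_weight = (routing_weights * expert_mask).sum(dim=-1, keepdim=True)
    expert_out = hidden_states                     # stand-in for self.experts[expert_idx](...)
    final_hidden_states += expert_out * expert_weight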

+ 45 - 30
aphrodite/modeling/models/olmo.py

@@ -37,6 +37,7 @@
 # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 """Inference-only OLMo model compatible with HuggingFace weights."""
+
 from typing import List, Optional, Tuple
 
 import torch
@@ -54,7 +55,9 @@ from aphrodite.modeling.layers.linear import (
 from aphrodite.modeling.layers.rotary_embedding import get_rope
 from aphrodite.modeling.layers.sampler import Sampler
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
-    VocabParallelEmbedding, ParallelLMHead)
+    VocabParallelEmbedding,
+    ParallelLMHead,
+)
 from aphrodite.modeling.megatron.parallel_state import (
     get_tensor_model_parallel_world_size, )
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
@@ -81,7 +84,8 @@ class SwiGLU(nn.Module):
 
 class OlmoAttention(nn.Module):
     """
-    This is the attention block where the output is computed as ``Attention(LN(x))`` in ``MLP(LN(x + Attention(LN(x))))``
+    This is the attention block where the output is computed as
+    ``Attention(LN(x))`` in ``MLP(LN(x + Attention(LN(x))))``
     (plus another skip connection).
     """
 
@@ -94,11 +98,12 @@ class OlmoAttention(nn.Module):
         self.config = config
         self.hidden_size = config.d_model
         assert config.d_model % config.n_heads == 0
-        tensor_model_parallel_world_size = get_tensor_model_parallel_world_size(
-        )
+        tensor_model_parallel_world_size = (
+            get_tensor_model_parallel_world_size())
         self.total_num_heads = self.config.n_heads
         assert self.total_num_heads % tensor_model_parallel_world_size == 0
-        self.num_heads = self.total_num_heads // tensor_model_parallel_world_size
+        self.num_heads = (self.total_num_heads //
+                          tensor_model_parallel_world_size)
         self.head_dim = self.hidden_size // self.total_num_heads
 
         # Layer norms.
@@ -158,7 +163,8 @@ class OlmoAttention(nn.Module):
 
 class OlmoMLP(nn.Module):
     """
-    This is the MLP block where the output is computed as ``MLP(LN(x))`` in ``MLP(LN(x + Attention(LN(x))))``
+    This is the MLP block where the output is computed as
+    ``MLP(LN(x))`` in ``MLP(LN(x + Attention(LN(x))))``
     (plus another skip connection).
     """
 
@@ -217,13 +223,16 @@ class OlmoMLP(nn.Module):
 
 class OlmoBlock(nn.Module):
     """
-    This is a typical transformer block where the output is computed as ``MLP(LN(x + Attention(LN(x))))``
+    This is a typical transformer block where the output is computed as
+    ``MLP(LN(x + Attention(LN(x))))``
     (plus another skip connection).
     """
 
-    def __init__(self,
-                 config: OLMoConfig,
-                 linear_method: Optional[LinearMethodBase] = None):
+    def __init__(
+        self,
+        config: OLMoConfig,
+        linear_method: Optional[LinearMethodBase] = None,
+    ):
         super().__init__()
         # Attention block.
         self.attn = OlmoAttention(config, linear_method)
@@ -250,27 +259,31 @@ class OlmoBlock(nn.Module):
 
 class OlmoModel(nn.Module):
 
-    def __init__(self,
-                 config: OLMoConfig,
-                 linear_method: Optional[LinearMethodBase] = None):
+    def __init__(
+        self,
+        config: OLMoConfig,
+        linear_method: Optional[LinearMethodBase] = None,
+    ):
         super().__init__()
         self.config = config
 
         self.transformer = nn.ModuleDict(
-            dict(wte=VocabParallelEmbedding(
-                config.embedding_size or config.vocab_size,
-                config.d_model,
-                linear_method=linear_method,
-            ),
-                 ln_f=nn.LayerNorm(config.d_model,
-                                   elementwise_affine=False,
-                                   bias=False),
-                 ff_out=ParallelLMHead(
-                     config.embedding_size or config.vocab_size,
-                     config.d_model,
-                     bias=config.include_bias,
-                     linear_method=linear_method,
-                 )))
+            dict(
+                wte=VocabParallelEmbedding(
+                    config.embedding_size or config.vocab_size,
+                    config.d_model,
+                    linear_method=linear_method,
+                ),
+                ln_f=nn.LayerNorm(config.d_model,
+                                  elementwise_affine=False,
+                                  bias=False),
+                ff_out=ParallelLMHead(
+                    config.embedding_size or config.vocab_size,
+                    config.d_model,
+                    bias=config.include_bias,
+                    linear_method=linear_method,
+                ),
+            ))
 
         blocks = [
             OlmoBlock(config, linear_method) for i in range(config.n_layers)
@@ -315,9 +328,11 @@ class OLMoForCausalLM(nn.Module):
     Extremely barebones HF model wrapper.
     """
 
-    def __init__(self,
-                 config: OLMoConfig,
-                 linear_method: Optional[LinearMethodBase] = None):
+    def __init__(
+        self,
+        config: OLMoConfig,
+        linear_method: Optional[LinearMethodBase] = None,
+    ):
         super().__init__()
         self.config = config
         self.linear_method = linear_method
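
The reflowed docstrings above describe OLMo's pre-norm residual layout, ``MLP(LN(x + Attention(LN(x))))`` plus skips. A compact sketch of that skeleton, with plain nn.Linear stand-ins for the attention and MLP blocks (illustrative only):

import torch
import torch.nn as nn

class ToyBlock(nn.Module):
    def __init__(self, d_model=16):
        super().__init__()
        self.attn_norm = nn.LayerNorm(d_model, elementwise_affine=False)
        self.ff_norm = nn.LayerNorm(d_model, elementwise_affine=False)
        self.attn = nn.Linear(d_model, d_model)  # stand-in for OlmoAttention
        self.mlp = nn.Linear(d_model, d_model)   # stand-in for OlmoMLP

    def forward(self, x):
        x = x + self.attn(self.attn_norm(x))  # Attention(LN(x)) + skip
        x = x + self.mlp(self.ff_norm(x))     # MLP(LN(x)) + skip
        return x

out = ToyBlock()(torch.randn(2, 5, 16))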

+ 56 - 36
aphrodite/modeling/models/opt.py

@@ -18,6 +18,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only OPT model compatible with HuggingFace weights."""
+
 from typing import List, Optional, Tuple
 
 import torch
@@ -27,19 +28,25 @@ from transformers import OPTConfig
 from aphrodite.modeling.metadata import InputMetadata
 from aphrodite.modeling.layers.activation import get_act_fn
 from aphrodite.modeling.layers.attention import PagedAttention
-from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
-                                              LinearMethodBase,
-                                              QKVParallelLinear,
-                                              ReplicatedLinear,
-                                              RowParallelLinear)
+from aphrodite.modeling.layers.linear import (
+    ColumnParallelLinear,
+    LinearMethodBase,
+    QKVParallelLinear,
+    ReplicatedLinear,
+    RowParallelLinear,
+)
 from aphrodite.modeling.layers.sampler import Sampler
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
-    VocabParallelEmbedding, ParallelLMHead)
+    VocabParallelEmbedding,
+    ParallelLMHead,
+)
 from aphrodite.modeling.megatron.parallel_state import (
-    get_tensor_model_parallel_world_size)
+    get_tensor_model_parallel_world_size, )
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
-from aphrodite.modeling.hf_downloader import (default_weight_loader,
-                                              hf_model_weights_iterator)
+from aphrodite.modeling.hf_downloader import (
+    default_weight_loader,
+    hf_model_weights_iterator,
+)
 from aphrodite.common.sequence import SamplerOutput
 
 KVCache = Tuple[torch.Tensor, torch.Tensor]
@@ -77,8 +84,8 @@ class OPTAttention(nn.Module):
         self.head_dim = embed_dim // total_num_heads
         self.scaling = self.head_dim**-0.5
 
-        if linear_method is not None and not linear_method.quant_config.merge_weight(
-        ):
+        if (linear_method is not None
+                and not linear_method.quant_config.merge_weight()):
             self.merge_weight = False
             self.q_proj = ColumnParallelLinear(embed_dim,
                                                embed_dim,
@@ -151,7 +158,8 @@ class OPTDecoderLayer(nn.Module):
 
         self.self_attn_layer_norm = nn.LayerNorm(
             self.embed_dim,
-            elementwise_affine=config.layer_norm_elementwise_affine)
+            elementwise_affine=config.layer_norm_elementwise_affine,
+        )
         self.fc1 = ColumnParallelLinear(
             self.embed_dim,
             config.ffn_dim,
@@ -169,7 +177,8 @@ class OPTDecoderLayer(nn.Module):
         )
         self.final_layer_norm = nn.LayerNorm(
             self.embed_dim,
-            elementwise_affine=config.layer_norm_elementwise_affine)
+            elementwise_affine=config.layer_norm_elementwise_affine,
+        )
 
     def forward(
         self,
@@ -182,9 +191,11 @@ class OPTDecoderLayer(nn.Module):
         # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention
         if self.do_layer_norm_before:
             hidden_states = self.self_attn_layer_norm(hidden_states)
-        hidden_states = self.self_attn(hidden_states=hidden_states,
-                                       kv_cache=kv_cache,
-                                       input_metadata=input_metadata)
+        hidden_states = self.self_attn(
+            hidden_states=hidden_states,
+            kv_cache=kv_cache,
+            input_metadata=input_metadata,
+        )
         hidden_states = residual + hidden_states
         # 350m applies layer norm AFTER attention
         if not self.do_layer_norm_before:
@@ -218,27 +229,33 @@ class OPTDecoder(nn.Module):
         self.max_target_positions = config.max_position_embeddings
         self.vocab_size = config.vocab_size
 
-        self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
-                                                   config.word_embed_proj_dim,
-                                                   linear_method=linear_method)
+        self.embed_tokens = VocabParallelEmbedding(
+            config.vocab_size,
+            config.word_embed_proj_dim,
+            linear_method=linear_method,
+        )
         # Positional embeddings are replicated (not sharded).
         self.embed_positions = OPTLearnedPositionalEmbedding(
             config.max_position_embeddings, config.hidden_size)
 
         # Project out & in will be replicated if they exist.
         if config.word_embed_proj_dim != config.hidden_size:
-            self.project_out = ReplicatedLinear(config.hidden_size,
-                                                config.word_embed_proj_dim,
-                                                bias=False,
-                                                linear_method=linear_method)
+            self.project_out = ReplicatedLinear(
+                config.hidden_size,
+                config.word_embed_proj_dim,
+                bias=False,
+                linear_method=linear_method,
+            )
         else:
             self.project_out = None
 
         if config.word_embed_proj_dim != config.hidden_size:
-            self.project_in = ReplicatedLinear(config.word_embed_proj_dim,
-                                               config.hidden_size,
-                                               bias=False,
-                                               linear_method=linear_method)
+            self.project_in = ReplicatedLinear(
+                config.word_embed_proj_dim,
+                config.hidden_size,
+                bias=False,
+                linear_method=linear_method,
+            )
         else:
             self.project_in = None
 
@@ -249,7 +266,8 @@ class OPTDecoder(nn.Module):
         if config.do_layer_norm_before and not config._remove_final_layer_norm:
             self.final_layer_norm = nn.LayerNorm(
                 config.hidden_size,
-                elementwise_affine=config.layer_norm_elementwise_affine)
+                elementwise_affine=config.layer_norm_elementwise_affine,
+            )
         else:
             self.final_layer_norm = None
 
@@ -338,19 +356,21 @@ class OPTForCausalLM(nn.Module):
                                    sampling_metadata)
         return next_tokens
 
-    def load_weights(self,
-                     model_name_or_path: str,
-                     cache_dir: Optional[str] = None,
-                     load_format: str = "auto",
-                     revision: Optional[str] = None):
+    def load_weights(
+        self,
+        model_name_or_path: str,
+        cache_dir: Optional[str] = None,
+        load_format: str = "auto",
+        revision: Optional[str] = None,
+    ):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("qkv_proj", "q_proj", "q"),
             ("qkv_proj", "k_proj", "k"),
             ("qkv_proj", "v_proj", "v"),
         ]
-        if self.linear_method is not None and not self.linear_method.quant_config.merge_weight(
-        ):
+        if (self.linear_method is not None
+                and not self.linear_method.quant_config.merge_weight()):
             stacked_params_mapping = []
         params_dict = dict(self.named_parameters(remove_duplicate=False))
         for name, loaded_weight in hf_model_weights_iterator(
@@ -372,7 +392,7 @@ class OPTForCausalLM(nn.Module):
             if name.startswith("decoder."):
                 name = "model." + name
 
-            for (param_name, weight_name, shard_id) in stacked_params_mapping:
+            for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
                     continue
                 name = name.replace(weight_name, param_name)
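
OPT only builds project_in/project_out when word_embed_proj_dim differs from hidden_size, as in the hunk above. A minimal sketch of that conditional wiring, with plain nn.Linear standing in for ReplicatedLinear and made-up sizes:

import torch
import torch.nn as nn

word_embed_proj_dim, hidden_size = 512, 768  # illustrative sizes

project_in = (nn.Linear(word_embed_proj_dim, hidden_size, bias=False)
              if word_embed_proj_dim != hidden_size else None)
project_out = (nn.Linear(hidden_size, word_embed_proj_dim, bias=False)
               if word_embed_proj_dim != hidden_size else None)

x = torch.randn(2, 4, word_embed_proj_dim)
h = project_in(x) if project_in is not None else x    # into the decoder width
o = project_out(h) if project_out is not None else h  # back out before the LM head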

+ 80 - 52
aphrodite/modeling/models/phi.py

@@ -36,6 +36,7 @@
 # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 """Inference-only Phi model compatible with HuggingFace weights."""
+
 from typing import List, Optional, Tuple
 
 import torch
@@ -45,19 +46,25 @@ from transformers import PretrainedConfig
 from aphrodite.modeling.metadata import InputMetadata
 from aphrodite.modeling.layers.activation import get_act_fn
 from aphrodite.modeling.layers.attention import PagedAttention
-from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
-                                              LinearMethodBase,
-                                              QKVParallelLinear,
-                                              RowParallelLinear)
+from aphrodite.modeling.layers.linear import (
+    ColumnParallelLinear,
+    LinearMethodBase,
+    QKVParallelLinear,
+    RowParallelLinear,
+)
 from aphrodite.modeling.layers.rotary_embedding import get_rope
 from aphrodite.modeling.layers.sampler import Sampler
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
-    VocabParallelEmbedding, ParallelLMHead)
+    VocabParallelEmbedding,
+    ParallelLMHead,
+)
 from aphrodite.modeling.megatron.parallel_state import (
-    get_tensor_model_parallel_world_size)
+    get_tensor_model_parallel_world_size, )
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
-from aphrodite.modeling.hf_downloader import (default_weight_loader,
-                                              hf_model_weights_iterator)
+from aphrodite.modeling.hf_downloader import (
+    default_weight_loader,
+    hf_model_weights_iterator,
+)
 from aphrodite.common.sequence import SamplerOutput
 
 KVCache = Tuple[torch.Tensor, torch.Tensor]
@@ -65,9 +72,11 @@ KVCache = Tuple[torch.Tensor, torch.Tensor]
 
 class PhiAttention(nn.Module):
 
-    def __init__(self,
-                 config: PretrainedConfig,
-                 linear_method: Optional[LinearMethodBase] = None):
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        linear_method: Optional[LinearMethodBase] = None,
+    ):
         super().__init__()
         self.total_num_heads = config.num_attention_heads
         self.hidden_size = config.hidden_size
@@ -80,21 +89,27 @@ class PhiAttention(nn.Module):
                           tensor_model_parallel_world_size)
 
         # pylint: disable=C0103
-        if linear_method is not None and not linear_method.quant_config.merge_weight(
-        ):
+        if (linear_method is not None
+                and not linear_method.quant_config.merge_weight()):
             self.merge_weight = False
-            self.q_proj = ColumnParallelLinear(self.hidden_size,
-                                               self.hidden_size,
-                                               bias=True,
-                                               linear_method=linear_method)
-            self.k_proj = ColumnParallelLinear(self.hidden_size,
-                                               self.hidden_size,
-                                               bias=True,
-                                               linear_method=linear_method)
-            self.v_proj = ColumnParallelLinear(self.hidden_size,
-                                               self.hidden_size,
-                                               bias=True,
-                                               linear_method=linear_method)
+            self.q_proj = ColumnParallelLinear(
+                self.hidden_size,
+                self.hidden_size,
+                bias=True,
+                linear_method=linear_method,
+            )
+            self.k_proj = ColumnParallelLinear(
+                self.hidden_size,
+                self.hidden_size,
+                bias=True,
+                linear_method=linear_method,
+            )
+            self.v_proj = ColumnParallelLinear(
+                self.hidden_size,
+                self.hidden_size,
+                bias=True,
+                linear_method=linear_method,
+            )
         else:
             self.merge_weight = True
             self.qkv_proj = QKVParallelLinear(
@@ -120,8 +135,9 @@ class PhiAttention(nn.Module):
         # https://huggingface.co/microsoft/phi-1_5/blob/d212a789620c380ff32ca1d1ee9943a777360987/modeling_phi.py#L518
         rope_theta = 10000
         max_position_embeddings = getattr(config, "n_positions", 2048)
-        is_neox_style = True if linear_method is None or linear_method.quant_config.rope_style(
-        ) is None else linear_method.quant_config.rope_style()
+        is_neox_style = (True if linear_method is None
+                         or linear_method.quant_config.rope_style() is None
+                         else linear_method.quant_config.rope_style())
         self.rotary_emb = get_rope(
             self.head_size,
             rotary_dim=rotary_dim,
@@ -154,9 +170,11 @@ class PhiAttention(nn.Module):
 
 class PhiMLP(nn.Module):
 
-    def __init__(self,
-                 config: PretrainedConfig,
-                 linear_method: Optional[LinearMethodBase] = None):
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        linear_method: Optional[LinearMethodBase] = None,
+    ):
         super().__init__()
 
         n_inner = getattr(config, "n_inner", None)
@@ -184,9 +202,11 @@ class PhiMLP(nn.Module):
 
 class PhiLayer(nn.Module):
 
-    def __init__(self,
-                 config: PretrainedConfig,
-                 linear_method: Optional[LinearMethodBase] = None):
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        linear_method: Optional[LinearMethodBase] = None,
+    ):
         super().__init__()
         self.input_layernorm = nn.LayerNorm(config.hidden_size,
                                             eps=config.layer_norm_eps)
@@ -215,9 +235,11 @@ class PhiLayer(nn.Module):
 
 class PhiModel(nn.Module):
 
-    def __init__(self,
-                 config: PretrainedConfig,
-                 linear_method: Optional[LinearMethodBase] = None):
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        linear_method: Optional[LinearMethodBase] = None,
+    ):
         super().__init__()
         self.config = config
         self.linear_method = linear_method
@@ -255,19 +277,23 @@ class PhiModel(nn.Module):
 
 class PhiForCausalLM(nn.Module):
 
-    def __init__(self,
-                 config: PretrainedConfig,
-                 linear_method: Optional[LinearMethodBase] = None):
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        linear_method: Optional[LinearMethodBase] = None,
+    ):
         super().__init__()
         self.config = config
         self.linear_method = linear_method
 
         self.model = PhiModel(config, linear_method)
 
-        self.lm_head = ParallelLMHead(config.vocab_size,
-                                      config.hidden_size,
-                                      bias=True,
-                                      linear_method=linear_method)
+        self.lm_head = ParallelLMHead(
+            config.vocab_size,
+            config.hidden_size,
+            bias=True,
+            linear_method=linear_method,
+        )
         self.sampler = Sampler(config.vocab_size)
 
     def forward(
@@ -291,19 +317,21 @@ class PhiForCausalLM(nn.Module):
                                    sampling_metadata)
         return next_tokens
 
-    def load_weights(self,
-                     model_name_or_path: str,
-                     cache_dir: Optional[str] = None,
-                     load_format: str = "auto",
-                     revision: Optional[str] = None):
+    def load_weights(
+        self,
+        model_name_or_path: str,
+        cache_dir: Optional[str] = None,
+        load_format: str = "auto",
+        revision: Optional[str] = None,
+    ):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("qkv_proj", "q_proj", "q"),
             ("qkv_proj", "k_proj", "k"),
-            ("qkv_proj", "v_proj", "v")
+            ("qkv_proj", "v_proj", "v"),
         ]
-        if self.linear_method is not None and not self.linear_method.quant_config.merge_weight(
-        ):
+        if (self.linear_method is not None
+                and not self.linear_method.quant_config.merge_weight()):
             stacked_params_mapping = []
         params_dict = dict(self.named_parameters())
 
@@ -313,7 +341,7 @@ class PhiForCausalLM(nn.Module):
             if "rotary_emb.inv_freq" in name:
                 continue
 
-            for (param_name, weight_name, shard_id) in stacked_params_mapping:
+            for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
                     continue
                 name = name.replace(weight_name, param_name)

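The load_weights hunk above leans on stacked_params_mapping to route per-projection checkpoint tensors (q_proj / k_proj / v_proj) into the fused qkv_proj parameter; when the quantization backend cannot merge weights, the mapping is emptied and every tensor keeps its original name. A minimal, self-contained sketch of just that renaming step, with made-up checkpoint names and no real tensors:

# Sketch of the name-remapping loop in load_weights(); the checkpoint
# entries below are illustrative, not taken from a real Phi checkpoint.
stacked_params_mapping = [
    # (fused param name, checkpoint shard name, shard id)
    ("qkv_proj", "q_proj", "q"),
    ("qkv_proj", "k_proj", "k"),
    ("qkv_proj", "v_proj", "v"),
]

checkpoint_names = [
    "model.layers.0.self_attn.q_proj.weight",
    "model.layers.0.self_attn.k_proj.weight",
    "model.layers.0.self_attn.v_proj.weight",
    "model.layers.0.mlp.fc1.weight",
]

for name in checkpoint_names:
    for param_name, weight_name, shard_id in stacked_params_mapping:
        if weight_name not in name:
            continue
        # The real loader copies the tensor into the shard_id slice of the
        # fused parameter; here only the rename is shown.
        print(f"{name} -> {name.replace(weight_name, param_name)} "
              f"[shard {shard_id}]")
        break
    else:
        print(f"{name} loaded as-is")
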
+ 68 - 46
aphrodite/modeling/models/qwen.py

@@ -4,6 +4,7 @@
 # Copyright (c) Alibaba Cloud.
 # LICENSE: https://huggingface.co/Qwen/Qwen-7B/blob/main/LICENSE
 """Inference-only QWen model compatible with HuggingFace weights."""
+
 from typing import Any, Dict, List, Optional, Tuple
 
 import torch
@@ -13,20 +14,26 @@ from aphrodite.modeling.metadata import InputMetadata
 from aphrodite.modeling.layers.activation import SiluAndMul
 from aphrodite.modeling.layers.attention import PagedAttention
 from aphrodite.modeling.layers.layernorm import RMSNorm
-from aphrodite.modeling.layers.linear import (LinearMethodBase,
-                                              MergedColumnParallelLinear,
-                                              QKVParallelLinear,
-                                              RowParallelLinear,
-                                              ColumnParallelLinear)
+from aphrodite.modeling.layers.linear import (
+    LinearMethodBase,
+    MergedColumnParallelLinear,
+    QKVParallelLinear,
+    RowParallelLinear,
+    ColumnParallelLinear,
+)
 from aphrodite.modeling.layers.rotary_embedding import get_rope
 from aphrodite.modeling.layers.sampler import Sampler
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
-    VocabParallelEmbedding, ParallelLMHead)
+    VocabParallelEmbedding,
+    ParallelLMHead,
+)
 from aphrodite.modeling.megatron.parallel_state import (
-    get_tensor_model_parallel_world_size)
+    get_tensor_model_parallel_world_size)
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
-from aphrodite.modeling.hf_downloader import (default_weight_loader,
-                                              hf_model_weights_iterator)
+from aphrodite.modeling.hf_downloader import (
+    default_weight_loader,
+    hf_model_weights_iterator,
+)
 from aphrodite.common.sequence import SamplerOutput
 from aphrodite.transformers_utils.configs.qwen import QWenConfig
 
@@ -43,27 +50,35 @@ class QWenMLP(nn.Module):
         linear_method: Optional[LinearMethodBase] = None,
     ):
         super().__init__()
-        if linear_method is not None and not linear_method.quant_config.merge_weight(
-        ):
+        if (linear_method is not None
+                and not linear_method.quant_config.merge_weight()):
             self.merge_weight = False
-            self.w2 = ColumnParallelLinear(hidden_size,
-                                           intermediate_size,
-                                           bias=False,
-                                           linear_method=linear_method)
-            self.w1 = ColumnParallelLinear(hidden_size,
-                                           intermediate_size,
-                                           bias=False,
-                                           linear_method=linear_method)
+            self.w2 = ColumnParallelLinear(
+                hidden_size,
+                intermediate_size,
+                bias=False,
+                linear_method=linear_method,
+            )
+            self.w1 = ColumnParallelLinear(
+                hidden_size,
+                intermediate_size,
+                bias=False,
+                linear_method=linear_method,
+            )
         else:
             self.merge_weight = True
             self.gate_up_proj = MergedColumnParallelLinear(
-                hidden_size, [intermediate_size] * 2,
+                hidden_size,
+                [intermediate_size] * 2,
                 bias=False,
-                linear_method=linear_method)
-        self.c_proj = RowParallelLinear(intermediate_size,
-                                        hidden_size,
-                                        bias=False,
-                                        linear_method=linear_method)
+                linear_method=linear_method,
+            )
+        self.c_proj = RowParallelLinear(
+            intermediate_size,
+            hidden_size,
+            bias=False,
+            linear_method=linear_method,
+        )
         if hidden_act != "silu":
             raise ValueError(f"Unsupported activation: {hidden_act}. "
                              "Only silu is supported for now.")
@@ -94,8 +109,8 @@ class QWenAttention(nn.Module):
     ):
         super().__init__()
         self.hidden_size = hidden_size
-        tensor_model_parallel_world_size = get_tensor_model_parallel_world_size(
-        )
+        tensor_model_parallel_world_size = (
+            get_tensor_model_parallel_world_size())
         self.total_num_heads = num_heads
         assert self.total_num_heads % tensor_model_parallel_world_size == 0
         self.num_heads = (self.total_num_heads //
@@ -116,8 +131,9 @@ class QWenAttention(nn.Module):
         )
         self.scaling = self.head_dim**-0.5
 
-        is_neox_style = True if linear_method is None or linear_method.quant_config.rope_style(
-        ) is None else linear_method.quant_config.rope_style()
+        is_neox_style = (True if linear_method is None
+                         or linear_method.quant_config.rope_style() is None
+                         else linear_method.quant_config.rope_style())
         self.rotary_emb = get_rope(
             self.head_dim,
             rotary_dim=self.head_dim,
@@ -157,18 +173,22 @@ class QWenBlock(nn.Module):
 
         rope_theta = getattr(config, "rope_theta", 10000)
         rope_scaling = getattr(config, "rope_scaling", None)
-        self.attn = QWenAttention(config.hidden_size,
-                                  config.num_attention_heads,
-                                  config.max_position_embeddings,
-                                  rope_theta=rope_theta,
-                                  rope_scaling=rope_scaling,
-                                  linear_method=linear_method)
+        self.attn = QWenAttention(
+            config.hidden_size,
+            config.num_attention_heads,
+            config.max_position_embeddings,
+            rope_theta=rope_theta,
+            rope_scaling=rope_scaling,
+            linear_method=linear_method,
+        )
 
         self.ln_2 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
 
-        self.mlp = QWenMLP(config.hidden_size,
-                           config.intermediate_size // 2,
-                           linear_method=linear_method)
+        self.mlp = QWenMLP(
+            config.hidden_size,
+            config.intermediate_size // 2,
+            linear_method=linear_method,
+        )
 
     def forward(
         self,
@@ -275,18 +295,20 @@ class QWenLMHeadModel(nn.Module):
                                    sampling_metadata)
         return next_tokens
 
-    def load_weights(self,
-                     model_name_or_path: str,
-                     cache_dir: Optional[str] = None,
-                     load_format: str = "auto",
-                     revision: Optional[str] = None):
+    def load_weights(
+        self,
+        model_name_or_path: str,
+        cache_dir: Optional[str] = None,
+        load_format: str = "auto",
+        revision: Optional[str] = None,
+    ):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("gate_up_proj", "w2", 0),
             ("gate_up_proj", "w1", 1),
         ]
-        if self.linear_method is not None and not self.linear_method.quant_config.merge_weight(
-        ):
+        if (self.linear_method is not None
+                and not self.linear_method.quant_config.merge_weight()):
             stacked_params_mapping = []
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in hf_model_weights_iterator(
@@ -294,7 +316,7 @@ class QWenLMHeadModel(nn.Module):
                 self.config):
             if "rotary_emb.inv_freq" in name:
                 continue
-            for (param_name, weight_name, shard_id) in stacked_params_mapping:
+            for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
                     continue
                 name = name.replace(weight_name, param_name)

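QWenMLP above keeps two code paths: separate w1 / w2 ColumnParallelLinear layers when the quantization backend reports merge_weight() as False, and a single fused gate_up_proj otherwise. A rough single-GPU analogue using plain torch.nn layers (the merge_weight flag below is a stand-in for the quant_config check, not the aphrodite API):

import torch
import torch.nn as nn


class ToyGatedMLP(nn.Module):
    """Gate/up projections either fused into one matmul or kept separate."""

    def __init__(self, hidden: int, inter: int, merge_weight: bool = True):
        super().__init__()
        self.merge_weight = merge_weight
        if merge_weight:
            # One matmul producing [gate | up] along the output dimension.
            self.gate_up = nn.Linear(hidden, 2 * inter, bias=False)
        else:
            self.gate = nn.Linear(hidden, inter, bias=False)
            self.up = nn.Linear(hidden, inter, bias=False)
        self.down = nn.Linear(inter, hidden, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.merge_weight:
            gate, up = self.gate_up(x).chunk(2, dim=-1)
        else:
            gate, up = self.gate(x), self.up(x)
        return self.down(nn.functional.silu(gate) * up)


x = torch.randn(2, 16)
print(ToyGatedMLP(16, 32, merge_weight=False)(x).shape)  # torch.Size([2, 16])
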
+ 86 - 59
aphrodite/modeling/models/qwen2.py

@@ -23,6 +23,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only Qwen2 model compatible with HuggingFace weights."""
+
 from typing import List, Optional, Tuple
 
 import torch
@@ -33,20 +34,26 @@ from aphrodite.modeling.metadata import InputMetadata
 from aphrodite.modeling.layers.activation import SiluAndMul
 from aphrodite.modeling.layers.attention import PagedAttention
 from aphrodite.modeling.layers.layernorm import RMSNorm
-from aphrodite.modeling.layers.linear import (LinearMethodBase,
-                                              ColumnParallelLinear,
-                                              MergedColumnParallelLinear,
-                                              QKVParallelLinear,
-                                              RowParallelLinear)
+from aphrodite.modeling.layers.linear import (
+    LinearMethodBase,
+    ColumnParallelLinear,
+    MergedColumnParallelLinear,
+    QKVParallelLinear,
+    RowParallelLinear,
+)
 from aphrodite.modeling.layers.rotary_embedding import get_rope
 from aphrodite.modeling.layers.sampler import Sampler
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
-    VocabParallelEmbedding, ParallelLMHead)
+    VocabParallelEmbedding,
+    ParallelLMHead,
+)
 from aphrodite.modeling.megatron.parallel_state import (
-    get_tensor_model_parallel_world_size)
+    get_tensor_model_parallel_world_size)
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
-from aphrodite.modeling.hf_downloader import (default_weight_loader,
-                                              hf_model_weights_iterator)
+from aphrodite.modeling.hf_downloader import (
+    default_weight_loader,
+    hf_model_weights_iterator,
+)
 from aphrodite.common.sequence import SamplerOutput
 
 KVCache = Tuple[torch.Tensor, torch.Tensor]
@@ -62,27 +69,35 @@ class Qwen2MLP(nn.Module):
         linear_method: Optional[LinearMethodBase] = None,
     ) -> None:
         super().__init__()
-        if linear_method is not None and not linear_method.quant_config.merge_weight(
-        ):
+        if (linear_method is not None
+                and not linear_method.quant_config.merge_weight()):
             self.merge_weight = False
-            self.gate_proj = ColumnParallelLinear(hidden_size,
-                                                  intermediate_size,
-                                                  bias=False,
-                                                  linear_method=linear_method)
-            self.up_proj = ColumnParallelLinear(hidden_size,
-                                                intermediate_size,
-                                                bias=False,
-                                                linear_method=linear_method)
+            self.gate_proj = ColumnParallelLinear(
+                hidden_size,
+                intermediate_size,
+                bias=False,
+                linear_method=linear_method,
+            )
+            self.up_proj = ColumnParallelLinear(
+                hidden_size,
+                intermediate_size,
+                bias=False,
+                linear_method=linear_method,
+            )
         else:
             self.merge_weight = True
             self.gate_up_proj = MergedColumnParallelLinear(
-                hidden_size, [intermediate_size] * 2,
+                hidden_size,
+                [intermediate_size] * 2,
                 bias=False,
-                linear_method=linear_method)
-        self.down_proj = RowParallelLinear(intermediate_size,
-                                           hidden_size,
-                                           bias=False,
-                                           linear_method=linear_method)
+                linear_method=linear_method,
+            )
+        self.down_proj = RowParallelLinear(
+            intermediate_size,
+            hidden_size,
+            bias=False,
+            linear_method=linear_method,
+        )
         if hidden_act != "silu":
             raise ValueError(f"Unsupported activation: {hidden_act}. "
                              "Only silu is supported for now.")
@@ -102,15 +117,17 @@ class Qwen2MLP(nn.Module):
 
 class Qwen2Attention(nn.Module):
 
-    def __init__(self,
-                 hidden_size: int,
-                 num_heads: int,
-                 num_kv_heads: int,
-                 max_position: int = 4096 * 32,
-                 rope_theta: float = 10000,
-                 use_sliding_window: bool = False,
-                 linear_method: Optional[LinearMethodBase] = None,
-                 sliding_window: Optional[int] = None) -> None:
+    def __init__(
+        self,
+        hidden_size: int,
+        num_heads: int,
+        num_kv_heads: int,
+        max_position: int = 4096 * 32,
+        rope_theta: float = 10000,
+        use_sliding_window: bool = False,
+        linear_method: Optional[LinearMethodBase] = None,
+        sliding_window: Optional[int] = None,
+    ) -> None:
         super().__init__()
         self.hidden_size = hidden_size
         tp_size = get_tensor_model_parallel_world_size()
@@ -134,21 +151,25 @@ class Qwen2Attention(nn.Module):
         self.rope_theta = rope_theta
         self.sliding_window = sliding_window if use_sliding_window else None
 
-        if linear_method is not None and not linear_method.quant_config.merge_weight(
-        ):
+        if (linear_method is not None
+                and not linear_method.quant_config.merge_weight()):
             self.merge_weight = False
             self.q_proj = ColumnParallelLinear(hidden_size,
                                                self.q_size,
                                                bias=True,
                                                linear_method=linear_method)
-            self.k_proj = ColumnParallelLinear(hidden_size,
-                                               self.kv_size,
-                                               bias=True,
-                                               linear_method=linear_method)
-            self.v_proj = ColumnParallelLinear(hidden_size,
-                                               self.kv_size,
-                                               bias=True,
-                                               linear_method=linear_method)
+            self.k_proj = ColumnParallelLinear(
+                hidden_size,
+                self.kv_size,
+                bias=True,
+                linear_method=linear_method,
+            )
+            self.v_proj = ColumnParallelLinear(
+                hidden_size,
+                self.kv_size,
+                bias=True,
+                linear_method=linear_method,
+            )
         else:
             self.merge_weight = True
             self.qkv_proj = QKVParallelLinear(
@@ -172,11 +193,13 @@ class Qwen2Attention(nn.Module):
             max_position=max_position,
             base=self.rope_theta,
         )
-        self.attn = PagedAttention(self.num_heads,
-                                   self.head_dim,
-                                   self.scaling,
-                                   num_kv_heads=self.num_kv_heads,
-                                   sliding_window=self.sliding_window)
+        self.attn = PagedAttention(
+            self.num_heads,
+            self.head_dim,
+            self.scaling,
+            num_kv_heads=self.num_kv_heads,
+            sliding_window=self.sliding_window,
+        )
 
     def forward(
         self,
@@ -212,7 +235,8 @@ class Qwen2DecoderLayer(nn.Module):
         self.hidden_size = config.hidden_size
         # Requires transformers > 4.32.0
         rope_theta = getattr(config, "rope_theta", 1000000)
-        use_sliding_window = config.use_sliding_window and layer_idx < config.max_window_layers
+        use_sliding_window = (config.use_sliding_window
+                              and layer_idx < config.max_window_layers)
         self.self_attn = Qwen2Attention(
             hidden_size=self.hidden_size,
             num_heads=config.num_attention_heads,
@@ -221,7 +245,8 @@ class Qwen2DecoderLayer(nn.Module):
             rope_theta=rope_theta,
             use_sliding_window=use_sliding_window,
             linear_method=linear_method,
-            sliding_window=config.sliding_window)
+            sliding_window=config.sliding_window,
+        )
         self.mlp = Qwen2MLP(
             hidden_size=self.hidden_size,
             intermediate_size=config.intermediate_size,
@@ -345,11 +370,13 @@ class Qwen2ForCausalLM(nn.Module):
                                    sampling_metadata)
         return next_tokens
 
-    def load_weights(self,
-                     model_name_or_path: str,
-                     cache_dir: Optional[str] = None,
-                     load_format: str = "auto",
-                     revision: Optional[str] = None):
+    def load_weights(
+        self,
+        model_name_or_path: str,
+        cache_dir: Optional[str] = None,
+        load_format: str = "auto",
+        revision: Optional[str] = None,
+    ):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("qkv_proj", "q_proj", "q"),
@@ -358,15 +385,15 @@ class Qwen2ForCausalLM(nn.Module):
             ("gate_up_proj", "gate_proj", 0),
             ("gate_up_proj", "up_proj", 1),
         ]
-        if self.linear_method is not None and not self.linear_method.quant_config.merge_weight(
-        ):
+        if (self.linear_method is not None
+                and not self.linear_method.quant_config.merge_weight()):
             stacked_params_mapping = []
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in hf_model_weights_iterator(
                 model_name_or_path, cache_dir, load_format, revision):
             if "rotary_emb.inv_freq" in name:
                 continue
-            for (param_name, weight_name, shard_id) in stacked_params_mapping:
+            for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
                     continue
                 name = name.replace(weight_name, param_name)

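One behavioural detail in the Qwen2DecoderLayer hunk is easy to miss among the reflowed arguments: sliding-window attention is only enabled while layer_idx < max_window_layers. A small stand-alone illustration of that per-layer decision, with invented config values:

from dataclasses import dataclass
from typing import List, Optional


@dataclass
class ToyQwen2Config:
    use_sliding_window: bool = True
    max_window_layers: int = 2
    sliding_window: int = 4096
    num_hidden_layers: int = 4


def window_for_layer(cfg: ToyQwen2Config, layer_idx: int) -> Optional[int]:
    # Mirrors the condition in the diff: only the first max_window_layers
    # layers use a bounded attention window.
    use_sliding_window = (cfg.use_sliding_window
                          and layer_idx < cfg.max_window_layers)
    return cfg.sliding_window if use_sliding_window else None


cfg = ToyQwen2Config()
windows: List[Optional[int]] = [
    window_for_layer(cfg, i) for i in range(cfg.num_hidden_layers)
]
print(windows)  # [4096, 4096, None, None]
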
+ 105 - 74
aphrodite/modeling/models/stablelm.py

@@ -1,5 +1,6 @@
 # coding=utf-8
-# Copyright 2023 Stability AI, EleutherAI, and The HuggingFace Inc. team. All rights reserved.
+# Copyright 2023 Stability AI, EleutherAI, and The HuggingFace Inc. team.
+# All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -16,7 +17,9 @@
 # This code is based off the following work:
 # https://huggingface.co/stabilityai/stablelm-3b-4e1t/blob/main/modeling_stablelm_epoch.py
 # https://huggingface.co/stabilityai/stablelm-3b-4e1t/blob/main/config.json
-"""Inference-only StabeLM (https://github.com/Stability-AI/StableLM) model compatible with HuggingFace weights."""
+"""Inference-only StabeLM (https://github.com/Stability-AI/StableLM) model
+compatible with HuggingFace weights."""
+
 from typing import List, Optional, Tuple
 
 import torch
@@ -26,20 +29,26 @@ from transformers import PretrainedConfig
 from aphrodite.modeling.metadata import InputMetadata
 from aphrodite.modeling.layers.activation import SiluAndMul
 from aphrodite.modeling.layers.attention import PagedAttention
-from aphrodite.modeling.layers.linear import (LinearMethodBase,
-                                              MergedColumnParallelLinear,
-                                              QKVParallelLinear,
-                                              RowParallelLinear,
-                                              ColumnParallelLinear)
+from aphrodite.modeling.layers.linear import (
+    LinearMethodBase,
+    MergedColumnParallelLinear,
+    QKVParallelLinear,
+    RowParallelLinear,
+    ColumnParallelLinear,
+)
 from aphrodite.modeling.layers.rotary_embedding import get_rope
 from aphrodite.modeling.layers.sampler import Sampler
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
-    VocabParallelEmbedding, ParallelLMHead)
+    VocabParallelEmbedding,
+    ParallelLMHead,
+)
 from aphrodite.modeling.megatron.parallel_state import (
-    get_tensor_model_parallel_world_size)
+    get_tensor_model_parallel_world_size)
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
-from aphrodite.modeling.hf_downloader import (default_weight_loader,
-                                              hf_model_weights_iterator)
+from aphrodite.modeling.hf_downloader import (
+    default_weight_loader,
+    hf_model_weights_iterator,
+)
 from aphrodite.common.sequence import SamplerOutput
 
 KVCache = Tuple[torch.Tensor, torch.Tensor]
@@ -47,30 +56,38 @@ KVCache = Tuple[torch.Tensor, torch.Tensor]
 
 class StablelmMLP(nn.Module):
 
-    def __init__(self,
-                 config: PretrainedConfig,
-                 linear_method: Optional[LinearMethodBase] = None) -> None:
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        linear_method: Optional[LinearMethodBase] = None,
+    ) -> None:
         super().__init__()
         self.config = config
         self.hidden_size = config.hidden_size
         self.intermediate_size = config.intermediate_size
-        if linear_method is not None and not linear_method.quant_config.merge_weight(
-        ):
+        if (linear_method is not None
+                and not linear_method.quant_config.merge_weight()):
             self.merge_weight = False
-            self.gate_proj = ColumnParallelLinear(config.hidden_size,
-                                                  config.intermediate_size,
-                                                  bias=False,
-                                                  linear_method=linear_method)
-            self.up_proj = ColumnParallelLinear(config.hidden_size,
-                                                config.intermediate_size,
-                                                bias=False,
-                                                linear_method=linear_method)
+            self.gate_proj = ColumnParallelLinear(
+                config.hidden_size,
+                config.intermediate_size,
+                bias=False,
+                linear_method=linear_method,
+            )
+            self.up_proj = ColumnParallelLinear(
+                config.hidden_size,
+                config.intermediate_size,
+                bias=False,
+                linear_method=linear_method,
+            )
         else:
             self.merge_weight = True
             self.gate_up_proj = MergedColumnParallelLinear(
-                config.hidden_size, [config.intermediate_size] * 2,
+                config.hidden_size,
+                [config.intermediate_size] * 2,
                 bias=False,
-                linear_method=linear_method)
+                linear_method=linear_method,
+            )
         self.down_proj = RowParallelLinear(config.intermediate_size,
                                            config.hidden_size,
                                            bias=False)
@@ -90,9 +107,11 @@ class StablelmMLP(nn.Module):
 
 class StablelmAttention(nn.Module):
 
-    def __init__(self,
-                 config: PretrainedConfig,
-                 linear_method: Optional[LinearMethodBase] = None) -> None:
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        linear_method: Optional[LinearMethodBase] = None,
+    ) -> None:
         super().__init__()
         self.config = config
         self.hidden_size = config.hidden_size
@@ -113,31 +132,38 @@ class StablelmAttention(nn.Module):
             1, self.total_num_key_value_heads // tp_size)
         self.head_dim = self.hidden_size // self.total_num_heads
         self.max_position_embeddings = config.max_position_embeddings
-        self.rotary_ndims = int(self.head_dim * self.config.rope_pct)
+        rope_pct = self.config.partial_rotary_factor
+        self.rotary_ndims = int(self.head_dim * rope_pct)
         self.scaling = self.head_dim**-0.5
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_key_value_heads * self.head_dim
         self.qkv_bias = getattr(config, "use_qkv_bias", False)
         if (self.head_dim * self.num_heads * tp_size) != self.hidden_size:
-            raise ValueError(
-                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
-                f" and `num_heads`: {self.num_heads}).")
+            raise ValueError("hidden_size must be divisible by num_heads (got "
+                             f"`hidden_size`: {self.hidden_size}"
+                             f" and `num_heads`: {self.num_heads}).")
 
-        if linear_method is not None and not linear_method.quant_config.merge_weight(
-        ):
+        if (linear_method is not None
+                and not linear_method.quant_config.merge_weight()):
             self.merge_weight = False
-            self.q_proj = ColumnParallelLinear(self.hidden_size,
-                                               self.q_size,
-                                               bias=self.qkv_bias,
-                                               linear_method=linear_method)
-            self.k_proj = ColumnParallelLinear(self.hidden_size,
-                                               self.kv_size,
-                                               bias=self.qkv_bias,
-                                               linear_method=linear_method)
-            self.v_proj = ColumnParallelLinear(self.hidden_size,
-                                               self.kv_size,
-                                               bias=self.qkv_bias,
-                                               linear_method=linear_method)
+            self.q_proj = ColumnParallelLinear(
+                self.hidden_size,
+                self.q_size,
+                bias=self.qkv_bias,
+                linear_method=linear_method,
+            )
+            self.k_proj = ColumnParallelLinear(
+                self.hidden_size,
+                self.kv_size,
+                bias=self.qkv_bias,
+                linear_method=linear_method,
+            )
+            self.v_proj = ColumnParallelLinear(
+                self.hidden_size,
+                self.kv_size,
+                bias=self.qkv_bias,
+                linear_method=linear_method,
+            )
         else:
             self.merge_weight = True
             self.qkv_proj = QKVParallelLinear(
@@ -148,24 +174,26 @@ class StablelmAttention(nn.Module):
                 self.qkv_bias,
                 linear_method=linear_method,
             )
-        self.o_proj = RowParallelLinear(self.total_num_heads * self.head_dim,
-                                        self.hidden_size,
-                                        bias=False,
-                                        linear_method=linear_method)
-        self.rotary_ndims = int(self.head_dim * self.config.rope_pct)
-        is_neox_style = True if linear_method is None or linear_method.quant_config.rope_style(
-        ) is None else linear_method.quant_config.rope_style()
+        self.o_proj = RowParallelLinear(
+            self.total_num_heads * self.head_dim,
+            self.hidden_size,
+            bias=False,
+            linear_method=linear_method,
+        )
+        self.rotary_ndims = int(self.head_dim *
+                                self.config.partial_rotary_factor)
         self.rotary_emb = get_rope(
             self.head_dim,
             rotary_dim=self.rotary_ndims,
             max_position=self.config.max_position_embeddings,
             base=self.config.rope_theta,
-            is_neox_style=is_neox_style,
         )
-        self.attn = PagedAttention(self.num_heads,
-                                   self.head_dim,
-                                   self.scaling,
-                                   num_kv_heads=self.num_key_value_heads)
+        self.attn = PagedAttention(
+            self.num_heads,
+            self.head_dim,
+            self.scaling,
+            num_kv_heads=self.num_key_value_heads,
+        )
 
     def forward(
         self,
@@ -200,9 +228,9 @@ class StablelmDecoderLayer(nn.Module):
         self.self_attn = StablelmAttention(config)
         self.mlp = StablelmMLP(config, linear_method)
         self.input_layernorm = nn.LayerNorm(config.hidden_size,
-                                            eps=config.norm_eps)
+                                            eps=config.layer_norm_eps)
         self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
-                                                     eps=config.norm_eps)
+                                                     eps=config.layer_norm_eps)
 
     def forward(
         self,
@@ -233,11 +261,12 @@ class StablelmDecoderLayer(nn.Module):
 
 class StableLMEpochModel(nn.Module):
 
-    def __init__(self,
-                 config: PretrainedConfig,
-                 linear_method: Optional[LinearMethodBase] = None) -> None:
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        linear_method: Optional[LinearMethodBase] = None,
+    ) -> None:
         super().__init__()
-        # self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id)
         self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
                                                    config.hidden_size,
                                                    linear_method=linear_method)
@@ -245,7 +274,7 @@ class StableLMEpochModel(nn.Module):
             StablelmDecoderLayer(config, linear_method)
             for _ in range(config.num_hidden_layers)
         ])
-        self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps)
+        self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
 
     def forward(
         self,
@@ -304,11 +333,13 @@ class StablelmForCausalLM(nn.Module):
                                    sampling_metadata)
         return next_tokens
 
-    def load_weights(self,
-                     model_name_or_path: str,
-                     cache_dir: Optional[str] = None,
-                     load_format: str = "auto",
-                     revision: Optional[str] = None):
+    def load_weights(
+        self,
+        model_name_or_path: str,
+        cache_dir: Optional[str] = None,
+        load_format: str = "auto",
+        revision: Optional[str] = None,
+    ):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("qkv_proj", "q_proj", "q"),
@@ -317,8 +348,8 @@ class StablelmForCausalLM(nn.Module):
             ("gate_up_proj", "gate_proj", 0),
             ("gate_up_proj", "up_proj", 1),
         ]
-        if self.linear_method is not None and not self.linear_method.quant_config.merge_weight(
-        ):
+        if (self.linear_method is not None
+                and not self.linear_method.quant_config.merge_weight()):
             stacked_params_mapping = []
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in hf_model_weights_iterator(
@@ -331,7 +362,7 @@ class StablelmForCausalLM(nn.Module):
                 # Models trained using ColossalAI may include these tensors in
                 # the checkpoint. Skip them.
                 continue
-            for (param_name, weight_name, shard_id) in stacked_params_mapping:
+            for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
                     continue
                 name = name.replace(weight_name, param_name)

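Besides reformatting, the stablelm.py hunks move the attention code from the legacy rope_pct / norm_eps attribute names to partial_rotary_factor / layer_norm_eps, matching the field names recent transformers StableLM configs expose, and drop the quantization-dependent is_neox_style argument. If code ever has to accept both config vintages, a defensive lookup along these lines is one option (the fallback below is purely illustrative; the commit itself assumes the new field names):

def rotary_ndims(config, head_dim: int) -> int:
    """Number of rotary dimensions, tolerant of old/new config field names."""
    # Prefer the new-style field, fall back to the legacy name, then to a
    # fully rotary head.
    rope_pct = getattr(config, "partial_rotary_factor",
                       getattr(config, "rope_pct", 1.0))
    return int(head_dim * rope_pct)


class _OldCfg:
    rope_pct = 0.25


class _NewCfg:
    partial_rotary_factor = 0.25


print(rotary_ndims(_OldCfg(), 64), rotary_ndims(_NewCfg(), 64))  # 16 16
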
+ 9 - 4
aphrodite/modeling/outlines_decoding.py

@@ -8,8 +8,14 @@ from re import escape as regex_escape
 from typing import Union, Tuple
 from pydantic import BaseModel
 
-from aphrodite.endpoints.openai.protocol import CompletionRequest, ChatCompletionRequest
-from aphrodite.modeling.outlines_logits_processors import JSONLogitsProcessor, RegexLogitsProcessor
+from aphrodite.endpoints.openai.protocol import (
+    CompletionRequest,
+    ChatCompletionRequest,
+)
+from aphrodite.modeling.outlines_logits_processors import (
+    JSONLogitsProcessor,
+    RegexLogitsProcessor,
+)
 
 
 class GuidedDecodingMode(Enum):
@@ -51,9 +57,8 @@ async def get_guided_decoding_logits_processor(
 
 
 def _get_guide_and_mode(
-    request: Union[CompletionRequest, ChatCompletionRequest]
+    request: Union[CompletionRequest, ChatCompletionRequest],
 ) -> Tuple[str, GuidedDecodingMode]:
-
     if request.guided_json:
         if not isinstance(request.guided_json, (str, dict, BaseModel)):
             raise TypeError("JSON schema must be str, dict, or BaseModel")

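outlines_decoding.py keeps its re.escape import (aliased regex_escape) alongside the guided_json handling shown above, which hints at how a CHOICE guide is usually built: escape each allowed string and join the options into one alternation regex before handing it to the regex processor. A hedged, stand-alone sketch of that idea (the helper name is invented):

import re
from re import escape as regex_escape
from typing import List


def choices_to_regex(choices: List[str]) -> str:
    # Escape literal characters such as "." or "?" in each option, then
    # join the options into a single alternation for the FSM to compile.
    return "(" + "|".join(regex_escape(choice) for choice in choices) + ")"


pattern = choices_to_regex(["yes", "no", "maybe?"])
print(pattern)                                # (yes|no|maybe\?)
print(bool(re.fullmatch(pattern, "maybe?")))  # True
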
+ 14 - 9
aphrodite/modeling/outlines_logits_processors.py

@@ -95,20 +95,25 @@ class RegexLogitsProcessor:
 
 class JSONLogitsProcessor(RegexLogitsProcessor):
 
-    def __init__(self,
-                 schema: Union[str, Dict, BaseModel],
-                 tokenizer,
-                 whitespace_pattern: Optional[str] = None):
+    def __init__(
+        self,
+        schema: Union[str, Dict, BaseModel],
+        tokenizer,
+        whitespace_pattern: Optional[str] = None,
+    ):
         """Compile the FSM that drives the JSON-guided generation.
         Parameters
         ----------
         schema
-            A JSON schema that encodes the structure we want the model to generate
+            A JSON schema that encodes the structure we want the model to
+                generate
         tokenizer
             The model's tokenizer
         whitespace_pattern
-            Pattern to use for JSON syntactic whitespace (doesn't impact string literals)
-            Example: allow only a single space or newline with `whitespace_pattern=r"[\n ]?"`
+            Pattern to use for JSON syntactic whitespace
+                (doesn't impact string literals)
+            Example: allow only a single space or newline with
+                `whitespace_pattern=r"[\n ]?"`
         """
         if isinstance(schema, type(BaseModel)):
             schema_str = json.dumps(schema.model_json_schema())
@@ -119,7 +124,7 @@ class JSONLogitsProcessor(RegexLogitsProcessor):
         else:
             raise ValueError(
                 f"Cannot parse schema {schema}. The schema must be either " +
-                "a Pydantic object, a dictionary or a string that contains the JSON "
-                + "Schema specification")
+                "a Pydantic object, a dictionary or a string that contains " +
+                "the JSON Schema specification")
         regex_string = build_regex_from_schema(schema_str, whitespace_pattern)
         super().__init__(regex_string, tokenizer)

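The JSONLogitsProcessor hunk above normalises its schema argument before building the guiding regex: a pydantic model class is serialised via model_json_schema(), a dict is dumped to JSON, and a string is passed through. The same normalisation in isolation (pydantic v2 assumed, class name invented):

import json
from typing import Dict, Union

from pydantic import BaseModel


def schema_to_str(schema: Union[str, Dict, BaseModel]) -> str:
    # Accept a Pydantic model class, a dict, or an already-serialised string.
    if isinstance(schema, type(BaseModel)):
        return json.dumps(schema.model_json_schema())
    if isinstance(schema, dict):
        return json.dumps(schema)
    if isinstance(schema, str):
        return schema
    raise ValueError(
        f"Cannot parse schema {schema}. The schema must be either " +
        "a Pydantic object, a dictionary or a string that contains " +
        "the JSON Schema specification")


class Point(BaseModel):
    x: int
    y: int


print(schema_to_str(Point)[:40])
print(schema_to_str({"type": "object"}))
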
+ 45 - 27
aphrodite/processing/block_manager.py

@@ -1,4 +1,5 @@
 """A block manager that manages token blocks."""
+
 import enum
 from itertools import count
 from os.path import commonprefix
@@ -18,12 +19,14 @@ class BlockAllocator:
     the reference count becomes zero, the block is added back to the free list.
     """
 
-    def __init__(self,
-                 device: Device,
-                 block_size: int,
-                 num_blocks: int,
-                 eviction_policy: EvictionPolicy = EvictionPolicy.LRU,
-                 enable_caching: bool = False) -> None:
+    def __init__(
+        self,
+        device: Device,
+        block_size: int,
+        num_blocks: int,
+        eviction_policy: EvictionPolicy = EvictionPolicy.LRU,
+        enable_caching: bool = False,
+    ) -> None:
         self.device = device
         self.block_size = block_size
         self.num_blocks = num_blocks
@@ -46,11 +49,13 @@ class BlockAllocator:
             block.block_hash = block_hash
             block.num_hashed_tokens = num_hashed_tokens
             return block
-        block = PhysicalTokenBlock(device=self.device,
-                                   block_number=self.current_num_blocks,
-                                   block_size=self.block_size,
-                                   block_hash=block_hash,
-                                   num_hashed_tokens=num_hashed_tokens)
+        block = PhysicalTokenBlock(
+            device=self.device,
+            block_number=self.current_num_blocks,
+            block_size=self.block_size,
+            block_hash=block_hash,
+            num_hashed_tokens=num_hashed_tokens,
+        )
         self.current_num_blocks += 1
         return block
 
@@ -95,13 +100,15 @@ class BlockAllocator:
                 del self.cached_blocks[block.block_hash]
 
     def get_num_free_blocks(self) -> int:
-        return self.num_blocks - self.current_num_blocks + self.evictor.num_blocks
+        return (self.num_blocks - self.current_num_blocks +
+                self.evictor.num_blocks)
 
     def contains_block(self, block_hash: int) -> bool:
         return block_hash in self.cached_blocks or block_hash in self.evictor
 
     def update_hash(self, block_hash: int, block: PhysicalTokenBlock):
-        # If caching is enabled, update the hash of block and the cached_blocks dictionary.
+        # If caching is enabled, update the hash of block and the cached_blocks
+        # dictionary.
         if self.enable_caching:
             assert not self.contains_block(block_hash)
             old_hash = block.block_hash
@@ -119,6 +126,7 @@ class AllocStatus(enum.Enum):
     3. Never: seq_group can never be allocated.
       The seq_group is too large to be allocated in GPU.
     """
+
     OK = enum.auto()
     LATER = enum.auto()
     NEVER = enum.auto()
@@ -142,8 +150,10 @@ class BlockSpaceManager:
 
         self.block_sliding_window = None
         if sliding_window is not None:
-            assert sliding_window % block_size == 0, (sliding_window,
-                                                      block_size)
+            assert sliding_window % block_size == 0, (
+                sliding_window,
+                block_size,
+            )
             self.block_sliding_window = sliding_window // block_size
 
         self.watermark = watermark
@@ -152,14 +162,18 @@ class BlockSpaceManager:
         self.enable_caching = enable_caching
 
         self.watermark_blocks = int(watermark * num_gpu_blocks)
-        self.gpu_allocator = BlockAllocator(Device.GPU,
-                                            block_size,
-                                            num_gpu_blocks,
-                                            enable_caching=enable_caching)
-        self.cpu_allocator = BlockAllocator(Device.CPU,
-                                            block_size,
-                                            num_cpu_blocks,
-                                            enable_caching=enable_caching)
+        self.gpu_allocator = BlockAllocator(
+            Device.GPU,
+            block_size,
+            num_gpu_blocks,
+            enable_caching=enable_caching,
+        )
+        self.cpu_allocator = BlockAllocator(
+            Device.CPU,
+            block_size,
+            num_cpu_blocks,
+            enable_caching=enable_caching,
+        )
         # Mapping: seq_id -> BlockTable.
         self.block_tables: Dict[int, BlockTable] = {}
 
@@ -199,7 +213,8 @@ class BlockSpaceManager:
             else:
                 block = self.gpu_allocator.allocate(
                     seq.hash_of_block(logical_idx),
-                    seq.num_hashed_tokens_of_block(logical_idx))
+                    seq.num_hashed_tokens_of_block(logical_idx),
+                )
             block_table.append(block)
 
         # Assign the block table for each sequence.
@@ -218,10 +233,12 @@ class BlockSpaceManager:
         seq: Sequence,
         last_block: PhysicalTokenBlock,
     ) -> PhysicalTokenBlock:
-        # Compute a new hash for the block so that it can be shared by other Sequences
+        # Compute a new hash for the block so that it can be shared by other
+        # Sequences
         new_hash = seq.hash_of_block(len(seq.logical_token_blocks) - 1)
 
-        # if new_hash is already in the cached table, then free last_block and return the cached version
+        # if new_hash is already in the cached table, then free last_block and
+        # return the cached version
         if self.gpu_allocator.contains_block(new_hash):
             self.gpu_allocator.free(last_block)
             return self.gpu_allocator.allocate(new_hash)
@@ -289,7 +306,8 @@ class BlockSpaceManager:
         assert last_block.device == Device.GPU
         if last_block.ref_count == 1:
             # Not shared with other sequences. Appendable.
-            # If the last block is now complete, promote it to a full block so that it can be shared
+            # If the last block is now complete, promote it to a full block so
+            # that it can be shared
             new_block = self._maybe_promote_last_block(seq, last_block)
             block_table[-1] = new_block
             return None

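Reflow aside, the BlockAllocator hunks above trace the prefix-caching flow: blocks are keyed by a content hash, allocate() first consults that cache, and a completed last block is "promoted" by rehashing it so later sequences can share it. A toy, dictionary-backed version of the allocate/free bookkeeping (devices, eviction, and the real PhysicalTokenBlock type are deliberately left out):

from dataclasses import dataclass, field
from typing import Dict


@dataclass
class ToyBlock:
    block_hash: int
    ref_count: int = 0


@dataclass
class ToyAllocator:
    num_blocks: int
    cached: Dict[int, ToyBlock] = field(default_factory=dict)
    used: int = 0

    def allocate(self, block_hash: int) -> ToyBlock:
        # Cache hit: reuse the existing block and bump its refcount.
        if block_hash in self.cached:
            block = self.cached[block_hash]
        else:
            assert self.used < self.num_blocks, "out of blocks"
            block = ToyBlock(block_hash)
            self.cached[block_hash] = block
            self.used += 1
        block.ref_count += 1
        return block

    def free(self, block: ToyBlock) -> None:
        block.ref_count -= 1  # a real allocator would evict at zero


alloc = ToyAllocator(num_blocks=4)
a = alloc.allocate(block_hash=123)
b = alloc.allocate(block_hash=123)  # shared prefix: same physical block
print(a is b, a.ref_count)          # True 2
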
+ 3 - 3
aphrodite/processing/evictor.py

@@ -39,9 +39,9 @@ class Evictor(ABC):
     @abstractmethod
     def remove(self, block_hash: int) -> PhysicalTokenBlock:
         """Simply removes the block with the hash value block_hash from the
-        evictor. Caller is responsible for making sure that block_hash is contained
-        in the evictor before calling remove. Should be used to "bring back" blocks
-        that have been freed but not evicted yet.
+        evictor. Caller is responsible for making sure that block_hash is
+        contained in the evictor before calling remove. Should be used to
+        "bring back" blocks that have been freed but not evicted yet.
         """
         pass
 

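The evictor.py change only rewraps the remove() docstring, but the contract it describes (remove() pulls a freed-but-not-yet-evicted block back out of the evictor, and the caller must check membership first) can be pictured with a small dict-backed toy; this illustrates the interface only, not the project's actual LRU policy:

from typing import Dict


class ToyEvictor:
    """Holds freed blocks until they are either evicted or reclaimed."""

    def __init__(self) -> None:
        self.free_table: Dict[int, str] = {}  # block_hash -> payload

    def __contains__(self, block_hash: int) -> bool:
        return block_hash in self.free_table

    def add(self, block_hash: int, payload: str) -> None:
        self.free_table[block_hash] = payload

    def remove(self, block_hash: int) -> str:
        # Caller must check membership first, matching the docstring.
        return self.free_table.pop(block_hash)


ev = ToyEvictor()
ev.add(42, "block-42")
if 42 in ev:
    print(ev.remove(42))  # block-42
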
+ 27 - 15
aphrodite/processing/scheduler.py

@@ -8,8 +8,13 @@ from aphrodite.common.config import CacheConfig, LoRAConfig, SchedulerConfig
 from aphrodite.processing.block_manager import AllocStatus, BlockSpaceManager
 from aphrodite.processing.policy import PolicyFactory
 from aphrodite.lora.request import LoRARequest
-from aphrodite.common.sequence import (Sequence, SequenceData, SequenceGroup,
-                                       SequenceGroupMetadata, SequenceStatus)
+from aphrodite.common.sequence import (
+    Sequence,
+    SequenceData,
+    SequenceGroup,
+    SequenceGroupMetadata,
+    SequenceStatus,
+)
 
 
 class PreemptionMode(enum.Enum):
@@ -21,6 +26,7 @@ class PreemptionMode(enum.Enum):
     recompute them when the sequences are resumed, treating the sequences as
     new prompts.
     """
+
     SWAP = enum.auto()
     RECOMPUTE = enum.auto()
 
@@ -59,8 +65,11 @@ class SchedulerOutputs:
     def _sort_by_lora_ids(self) -> bool:
         self.scheduled_seq_groups = sorted(
             self.scheduled_seq_groups,
-            key=lambda g: (g.lora_request.lora_int_id
-                           if g.lora_request else 0, g.request_id))
+            key=lambda g: (
+                g.lora_request.lora_int_id if g.lora_request else 0,
+                g.request_id,
+            ),
+        )
 
     @property
     def lora_requests(self) -> Set[LoRARequest]:
@@ -82,8 +91,10 @@ class Scheduler:
         # LoRAs. This should be improved in the future.
         self.lora_config = lora_config
 
-        self.prompt_limit = min(self.scheduler_config.max_model_len,
-                                self.scheduler_config.max_num_batched_tokens)
+        self.prompt_limit = min(
+            self.scheduler_config.max_model_len,
+            self.scheduler_config.max_num_batched_tokens,
+        )
 
         # Instantiate the scheduling policy.
         self.policy = PolicyFactory.get_policy(policy_name="fcfs")
@@ -93,7 +104,8 @@ class Scheduler:
             num_gpu_blocks=self.cache_config.num_gpu_blocks,
             num_cpu_blocks=self.cache_config.num_cpu_blocks,
             sliding_window=self.cache_config.sliding_window,
-            enable_caching=self.cache_config.context_shift)
+            enable_caching=self.cache_config.context_shift,
+        )
 
         # Sequence groups in the WAITING state.
         self.waiting: Deque[SequenceGroup] = deque()
@@ -169,9 +181,9 @@ class Scheduler:
             # requests in the generation phase.
             num_curr_seqs = sum(seq_group.get_max_num_running_seqs()
                                 for seq_group in self.running)
-            curr_loras = set(
+            curr_loras = (set(
                 seq_group.lora_int_id
-                for seq_group in self.running) if self.lora_enabled else None
+                for seq_group in self.running) if self.lora_enabled else None)
             seq_lens: List[int] = []
 
             # Optimization: We do not sort the waiting queue since the preempted
@@ -213,8 +225,8 @@ class Scheduler:
                 lora_int_id = 0
                 if self.lora_enabled:
                     lora_int_id = seq_group.lora_int_id
-                    if lora_int_id > 0 and lora_int_id not in curr_loras and len(
-                            curr_loras) >= self.lora_config.max_loras:
+                    if (lora_int_id > 0 and lora_int_id not in curr_loras
+                            and len(curr_loras) >= self.lora_config.max_loras):
                         # We don't have a space for another LoRA, so
                         # we ignore this request for now.
                         leftover_waiting_sequences.appendleft(seq_group)
@@ -297,9 +309,9 @@ class Scheduler:
         if not preempted:
             num_curr_seqs = sum(seq_group.get_max_num_running_seqs()
                                 for seq_group in self.running)
-            curr_loras = set(
+            curr_loras = (set(
                 seq_group.lora_int_id
-                for seq_group in self.running) if self.lora_enabled else None
+                for seq_group in self.running) if self.lora_enabled else None)
 
             leftover_swapped = deque()
 
@@ -308,8 +320,8 @@ class Scheduler:
                 lora_int_id = 0
                 if self.lora_enabled:
                     lora_int_id = seq_group.lora_int_id
-                    if lora_int_id > 0 and lora_int_id not in curr_loras and len(
-                            curr_loras) >= self.lora_config.max_loras:
+                    if (lora_int_id > 0 and lora_int_id not in curr_loras
+                            and len(curr_loras) >= self.lora_config.max_loras):
                         # We don't have a space for another LoRA, so
                         # we ignore this request for now.
                         leftover_swapped.appendleft(seq_group)

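The scheduler hunks reflow the same LoRA admission check in both the waiting and swapped paths: a request that brings a new adapter is deferred once the batch already holds max_loras distinct LoRAs. Stripped of the queue bookkeeping, the decision reduces to something like the following (function and variable names are simplified stand-ins):

from typing import Set


def admit_lora(lora_int_id: int, curr_loras: Set[int], max_loras: int) -> bool:
    """Return True if the request can join the current batch."""
    if (lora_int_id > 0 and lora_int_id not in curr_loras
            and len(curr_loras) >= max_loras):
        return False  # no free LoRA slot; leave the request for later
    if lora_int_id > 0:
        curr_loras.add(lora_int_id)
    return True


active: Set[int] = {1, 2}
print(admit_lora(3, active, max_loras=2))  # False: both slots taken
print(admit_lora(1, active, max_loras=2))  # True: adapter already resident
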
+ 208 - 135
aphrodite/task_handler/model_runner.py

@@ -7,19 +7,28 @@ import torch
 import torch.nn as nn
 from loguru import logger
 
-from aphrodite.common.config import (DeviceConfig, ModelConfig, LoRAConfig,
-                                     ParallelConfig, SchedulerConfig)
+from aphrodite.common.config import (
+    DeviceConfig,
+    ModelConfig,
+    LoRAConfig,
+    ParallelConfig,
+    SchedulerConfig,
+)
 from aphrodite.common.logger import get_loading_progress_bar
 from aphrodite.modeling import get_model, InputMetadata, SamplingMetadata
 from aphrodite.modeling.megatron import cupy_utils
-from aphrodite.modeling.megatron.communication_op import (broadcast_tensor_dict
-                                                          )
+from aphrodite.modeling.megatron.communication_op import broadcast_tensor_dict
 from aphrodite.modeling.megatron.parallel_state import (
-    get_tensor_model_parallel_world_size, with_cupy_nccl_for_all_reduce)
+    get_tensor_model_parallel_world_size,
+    with_cupy_nccl_for_all_reduce,
+)
 from aphrodite.modeling.megatron import custom_all_reduce
 from aphrodite.common.sampling_params import SamplingParams, SamplingType
-from aphrodite.common.sequence import (SamplerOutput, SequenceData,
-                                       SequenceGroupMetadata)
+from aphrodite.common.sequence import (
+    SamplerOutput,
+    SequenceData,
+    SequenceGroupMetadata,
+)
 from aphrodite.modeling.sampling_metadata import PersistentMetadata
 from aphrodite.lora.worker_manager import LRUCacheWorkerLoRAManager
 from aphrodite.lora.layers import LoRAMapping
@@ -80,9 +89,9 @@ class ModelRunner:
         # cache in_wsl result
         self.in_wsl = in_wsl()
         self.kv_cache_dtype = kv_cache_dtype
-        self.kv_quant_params = self.load_kv_quant_params(
-            model_config,
-            kv_quant_params_path) if self.kv_cache_dtype == "int8" else None
+        self.kv_quant_params = (self.load_kv_quant_params(
+            model_config, kv_quant_params_path)
+                                if self.kv_cache_dtype == "int8" else None)
 
     def load_kv_quant_params(self, model_config: ModelConfig,
                              kv_quant_params_path: str) -> List[List[float]]:
@@ -93,14 +102,16 @@ class ModelRunner:
         for arch in architectures:
             if arch not in ["LlamaForCausalLM", "LLaMAForCausalLM"]:
                 raise ValueError(
-                    f"KV CACHE INT8 is not supported for model architectures {arch} for now. "
-                    f"Supported architectures: LlamaForCausalLM and LLaMAForCausalLM."
-                )
+                    "KV CACHE INT8 is not supported for model architectures "
+                    f"{arch} for now. "
+                    "Supported architectures: LlamaForCausalLM and "
+                    "LLaMAForCausalLM.")
         num_layers = model_config.hf_config.num_hidden_layers
         kv_quant_params = []
         for i in range(num_layers):
             if kv_quant_params_path is not None:
-                path = kv_quant_params_path + f"/layers.{i}.past_kv_scale.0.weight"
+                path = (kv_quant_params_path +
+                        f"/layers.{i}.past_kv_scale.0.weight")
                 kv_quant_param = list(np.fromfile(path, dtype=np.float32))
             kv_quant_params.append(kv_quant_param)
         return kv_quant_params
@@ -119,9 +130,9 @@ class ModelRunner:
         vocab_size = self.model.config.vocab_size
 
         if self.lora_config:
-            assert hasattr(
-                self.model, "supported_lora_modules"
-            ) and self.model.supported_lora_modules, "Model does not support LoRA"
+            assert (hasattr(self.model, "supported_lora_modules")
+                    and self.model.supported_lora_modules
+                    ), "Model does not support LoRA"
             assert hasattr(
                 self.model,
                 "embedding_modules"), "Model does not have embedding_modules"
@@ -130,9 +141,13 @@ class ModelRunner:
             self.lora_manager = LRUCacheWorkerLoRAManager(
                 self.scheduler_config.max_num_seqs,
                 self.scheduler_config.max_num_batched_tokens +
-                self.scheduler_config.max_paddings, vocab_size,
-                self.lora_config, self.device, self.model.embedding_modules,
-                self.model.embedding_padding_modules)
+                self.scheduler_config.max_paddings,
+                vocab_size,
+                self.lora_config,
+                self.device,
+                self.model.embedding_modules,
+                self.model.embedding_padding_modules,
+            )
             self.model = self.lora_manager.create_lora_manager(self.model)
 
     def set_block_size(self, block_size: int) -> None:
@@ -147,7 +162,7 @@ class ModelRunner:
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],
     ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, List[int], List[int],
-               List[int], List[int], Set[LoRARequest]]:
+               List[int], List[int], Set[LoRARequest]]:
         assert len(seq_group_metadata_list) > 0
         input_tokens: List[List[int]] = []
         input_positions: List[List[int]] = []
@@ -174,8 +189,9 @@ class ModelRunner:
 
             # NOTE: This only works for oooooooxxx style attention.
             computed_block_nums = seq_group_metadata.computed_block_nums
-            if computed_block_nums is not None and len(
-                    computed_block_nums) > 0 and self.sliding_window is None:
+            if (computed_block_nums is not None
+                    and len(computed_block_nums) > 0
+                    and self.sliding_window is None):
                 # Prefix is not supported with sliding_window
                 computed_len = len(computed_block_nums) * self.block_size
                 prompt_tokens = prompt_tokens[computed_len:]
@@ -234,21 +250,27 @@ class ModelRunner:
                 slot_mapping[-1].append(slot)
 
         max_prompt_len = max(subquery_lens)
-        input_tokens = _make_tensor_with_pad(input_tokens,
-                                             max_prompt_len,
-                                             pad=0,
-                                             dtype=torch.long,
-                                             device=self.device)
-        input_positions = _make_tensor_with_pad(input_positions,
-                                                max_prompt_len,
-                                                pad=0,
-                                                dtype=torch.long,
-                                                device=self.device)
-        slot_mapping = _make_tensor_with_pad(slot_mapping,
-                                             max_prompt_len,
-                                             pad=_PAD_SLOT_ID,
-                                             dtype=torch.long,
-                                             device=self.device)
+        input_tokens = _make_tensor_with_pad(
+            input_tokens,
+            max_prompt_len,
+            pad=0,
+            dtype=torch.long,
+            device=self.device,
+        )
+        input_positions = _make_tensor_with_pad(
+            input_positions,
+            max_prompt_len,
+            pad=0,
+            dtype=torch.long,
+            device=self.device,
+        )
+        slot_mapping = _make_tensor_with_pad(
+            slot_mapping,
+            max_prompt_len,
+            pad=_PAD_SLOT_ID,
+            dtype=torch.long,
+            device=self.device,
+        )
         lora_index_mapping = [
             _pad_to_max(mapping, max_prompt_len, pad=0)
             for mapping in lora_index_mapping
@@ -265,11 +287,13 @@ class ModelRunner:
             dtype=torch.int,
             device=self.device,
         )
-        start_loc_tensor = torch.arange(0,
-                                        len(prompt_lens) * max_prompt_len,
-                                        max_prompt_len,
-                                        dtype=torch.long,
-                                        device=self.device)
+        start_loc_tensor = torch.arange(
+            0,
+            len(prompt_lens) * max_prompt_len,
+            max_prompt_len,
+            dtype=torch.long,
+            device=self.device,
+        )
         prompt_lens_tensor = torch.tensor(prompt_lens,
                                           dtype=torch.long,
                                           device=self.device)
@@ -287,15 +311,22 @@ class ModelRunner:
             kv_cache_dtype=self.kv_cache_dtype,
             kv_quant_params=self.kv_quant_params,
         )
-        return (input_tokens, input_positions, input_metadata, prompt_lens,
-                subquery_lens, lora_index_mapping, lora_prompt_mapping,
-                lora_requests)
+        return (
+            input_tokens,
+            input_positions,
+            input_metadata,
+            prompt_lens,
+            subquery_lens,
+            lora_index_mapping,
+            lora_prompt_mapping,
+            lora_requests,
+        )
 
     def _prepare_decode(
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],
     ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, List[int], List[int],
-               Set[LoRARequest]]:
+               Set[LoRARequest]]:
         assert len(seq_group_metadata_list) > 0
         input_tokens: List[List[int]] = []
         input_positions: List[List[int]] = []
@@ -324,8 +355,8 @@ class ModelRunner:
                 position = seq_len - 1
                 input_positions.append([position])
 
-                context_len = seq_len if self.sliding_window is None else min(
-                    seq_len, self.sliding_window)
+                context_len = (seq_len if self.sliding_window is None else min(
+                    seq_len, self.sliding_window))
                 context_lens.append(context_len)
 
                 block_table = seq_group_metadata.block_tables[seq_id]
@@ -366,16 +397,20 @@ class ModelRunner:
                                              pad=0,
                                              dtype=torch.long,
                                              device=self.device)
-        input_positions = _make_tensor_with_pad(input_positions,
-                                                max_len=1,
-                                                pad=0,
-                                                dtype=torch.long,
-                                                device=self.device)
-        slot_mapping = _make_tensor_with_pad(slot_mapping,
-                                             max_len=1,
-                                             pad=_PAD_SLOT_ID,
-                                             dtype=torch.long,
-                                             device=self.device)
+        input_positions = _make_tensor_with_pad(
+            input_positions,
+            max_len=1,
+            pad=0,
+            dtype=torch.long,
+            device=self.device,
+        )
+        slot_mapping = _make_tensor_with_pad(
+            slot_mapping,
+            max_len=1,
+            pad=_PAD_SLOT_ID,
+            dtype=torch.long,
+            device=self.device,
+        )
         context_lens = torch.tensor(context_lens,
                                     dtype=torch.int,
                                     device=self.device)
@@ -416,8 +451,14 @@ class ModelRunner:
             kv_cache_dtype=self.kv_cache_dtype,
             kv_quant_params=self.kv_quant_params,
         )
-        return (input_tokens, input_positions, input_metadata,
-                lora_index_mapping, lora_prompt_mapping, lora_requests)
+        return (
+            input_tokens,
+            input_positions,
+            input_metadata,
+            lora_index_mapping,
+            lora_prompt_mapping,
+            lora_requests,
+        )
 
     def _prepare_sample(
         self,
@@ -453,8 +494,10 @@ class ModelRunner:
 
                 if sampling_params.prompt_logprobs is not None:
                     selected_token_indices.extend(
-                        range(selected_token_start_idx,
-                              selected_token_start_idx + subquery_len - 1))
+                        range(
+                            selected_token_start_idx,
+                            selected_token_start_idx + subquery_len - 1,
+                        ))
                 selected_token_indices.append(selected_token_start_idx +
                                               subquery_len - 1)
                 selected_token_start_idx += max_subquery_len
@@ -464,28 +507,36 @@ class ModelRunner:
             else:
                 num_seqs = len(seq_ids)
                 selected_token_indices.extend(
-                    range(selected_token_start_idx,
-                          selected_token_start_idx + num_seqs))
+                    range(
+                        selected_token_start_idx,
+                        selected_token_start_idx + num_seqs,
+                    ))
                 selected_token_start_idx += num_seqs
 
                 categorized_sample_indices[
                     sampling_params.sampling_type].extend(
-                        range(categorized_sample_indices_start_idx,
-                              categorized_sample_indices_start_idx + num_seqs))
+                        range(
+                            categorized_sample_indices_start_idx,
+                            categorized_sample_indices_start_idx + num_seqs,
+                        ))
                 categorized_sample_indices_start_idx += num_seqs
 
             if sampling_params.seed is not None:
                 generators.append(seq_group_metadata.state.generator)
 
-        selected_token_indices = _async_h2d(selected_token_indices,
-                                            dtype=torch.long,
-                                            target_device=self.device,
-                                            pin_memory=not self.in_wsl)
+        selected_token_indices = _async_h2d(
+            selected_token_indices,
+            dtype=torch.long,
+            target_device=self.device,
+            pin_memory=not self.in_wsl,
+        )
         categorized_sample_indices = {
-            t: _async_h2d(seq_ids,
-                          dtype=torch.int,
-                          target_device=self.device,
-                          pin_memory=not self.in_wsl)
+            t: _async_h2d(
+                seq_ids,
+                dtype=torch.int,
+                target_device=self.device,
+                pin_memory=not self.in_wsl,
+            )
             for t, seq_ids in categorized_sample_indices.items()
         }
 
@@ -512,20 +563,32 @@ class ModelRunner:
         self,
         seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
     ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, SamplingMetadata,
-               Set[int], LoRAMapping]:
+               Set[int], LoRAMapping, ]:
         if self.is_driver_worker:
             # NOTE: We assume that all sequences in the group are all prompts or
             # all decodes.
             is_prompt = seq_group_metadata_list[0].is_prompt
             # Prepare input tensors.
             if is_prompt:
-                (input_tokens, input_positions, input_metadata, prompt_lens,
-                 subquery_lens, lora_index_mapping, lora_prompt_mapping,
-                 lora_requests) = self._prepare_prompt(seq_group_metadata_list)
+                (
+                    input_tokens,
+                    input_positions,
+                    input_metadata,
+                    prompt_lens,
+                    subquery_lens,
+                    lora_index_mapping,
+                    lora_prompt_mapping,
+                    lora_requests,
+                ) = self._prepare_prompt(seq_group_metadata_list)
             else:
-                (input_tokens, input_positions, input_metadata,
-                 lora_index_mapping, lora_prompt_mapping,
-                 lora_requests) = self._prepare_decode(seq_group_metadata_list)
+                (
+                    input_tokens,
+                    input_positions,
+                    input_metadata,
+                    lora_index_mapping,
+                    lora_prompt_mapping,
+                    lora_requests,
+                ) = self._prepare_decode(seq_group_metadata_list)
                 prompt_lens = []
                 subquery_lens = None
             sampling_metadata = self._prepare_sample(seq_group_metadata_list,
@@ -559,7 +622,7 @@ class ModelRunner:
                 "kv_cache_dtype": input_metadata.kv_cache_dtype,
                 "kv_quant_params": input_metadata.kv_quant_params,
                 "selected_token_indices":
-                sampling_metadata.selected_token_indices,
+                sampling_metadata.selected_token_indices,  # noqa
                 "lora_requests": lora_requests,
                 "lora_mapping": lora_mapping,
             }
@@ -593,8 +656,14 @@ class ModelRunner:
                 perform_sampling=False,
             )
 
-        return (input_tokens, input_positions, input_metadata,
-                sampling_metadata, lora_requests, lora_mapping)
+        return (
+            input_tokens,
+            input_positions,
+            input_metadata,
+            sampling_metadata,
+            lora_requests,
+            lora_mapping,
+        )
 
     @torch.inference_mode()
     def execute_model(
@@ -602,9 +671,14 @@ class ModelRunner:
         seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
         kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
     ) -> Optional[SamplerOutput]:
-        (input_tokens, input_positions, input_metadata, sampling_metadata,
-         lora_requests,
-         lora_mapping) = (self.prepare_input_tensors(seq_group_metadata_list))
+        (
+            input_tokens,
+            input_positions,
+            input_metadata,
+            sampling_metadata,
+            lora_requests,
+            lora_mapping,
+        ) = self.prepare_input_tensors(seq_group_metadata_list)
 
         if self.lora_config:
             self.set_active_loras(lora_requests, lora_mapping)
@@ -663,8 +737,8 @@ class ModelRunner:
         # number of tokens equal to max_num_batched_tokens.
         seqs: List[SequenceGroupMetadata] = []
         for group_id in range(max_num_seqs):
-            seq_len = (max_num_batched_tokens // max_num_seqs +
-                       (group_id < max_num_batched_tokens % max_num_seqs))
+            seq_len = max_num_batched_tokens // max_num_seqs + (
+                group_id < max_num_batched_tokens % max_num_seqs)
             seq_data = SequenceData([0] * seq_len)
             seq = SequenceGroupMetadata(
                 request_id=str(group_id),
@@ -756,45 +830,44 @@ class ModelRunner:
         task = progress.add_task("[cyan]Capturing graph...",
                                  total=len(batch_size_capture_list))
 
-        with progress:
-            with custom_all_reduce.capture():
-                for batch_size in reversed(batch_size_capture_list):
-                    if batch_size > self.scheduler_config.max_num_seqs:
-                        continue
-                    # Create dummy input_metadata.
-                    input_metadata = InputMetadata(
-                        is_prompt=False,
-                        slot_mapping=slot_mapping[:batch_size],
-                        prompt_lens=None,
-                        max_seq_len=None,
-                        start_loc=None,
-                        max_context_len=self.max_context_len_to_capture,
-                        context_lens=context_lens[:batch_size],
-                        block_tables=block_tables[:batch_size],
-                        use_cuda_graph=True,
-                        kv_cache_dtype=self.kv_cache_dtype,
-                        kv_quant_params=self.kv_quant_params,
-                    )
+        with progress, custom_all_reduce.capture():
+            for batch_size in reversed(batch_size_capture_list):
+                if batch_size > self.scheduler_config.max_num_seqs:
+                    continue
+                # Create dummy input_metadata.
+                input_metadata = InputMetadata(
+                    is_prompt=False,
+                    slot_mapping=slot_mapping[:batch_size],
+                    prompt_lens=None,
+                    max_seq_len=None,
+                    start_loc=None,
+                    max_context_len=self.max_context_len_to_capture,
+                    context_lens=context_lens[:batch_size],
+                    block_tables=block_tables[:batch_size],
+                    use_cuda_graph=True,
+                    kv_cache_dtype=self.kv_cache_dtype,
+                    kv_quant_params=self.kv_quant_params,
+                )
 
-                    if self.lora_config:
-                        lora_mapping = LoRAMapping(
-                            [0] * batch_size,
-                            [0] * batch_size,
-                        )
-                        self.set_active_loras(set(), lora_mapping)
-
-                    graph_runner = CUDAGraphRunner(self.model)
-                    graph_runner.capture(
-                        input_tokens[:batch_size],
-                        input_positions[:batch_size],
-                        kv_caches,
-                        input_metadata,
-                        memory_pool=self.graph_memory_pool,
+                if self.lora_config:
+                    lora_mapping = LoRAMapping(
+                        [0] * batch_size,
+                        [0] * batch_size,
                     )
-                    self.graph_memory_pool = graph_runner.graph.pool()
-                    self.graph_runners[batch_size] = graph_runner
-                    # Update the progress bar
-                    progress.update(task, advance=1)
+                    self.set_active_loras(set(), lora_mapping)
+
+                graph_runner = CUDAGraphRunner(self.model)
+                graph_runner.capture(
+                    input_tokens[:batch_size],
+                    input_positions[:batch_size],
+                    kv_caches,
+                    input_metadata,
+                    memory_pool=self.graph_memory_pool,
+                )
+                self.graph_memory_pool = graph_runner.graph.pool()
+                self.graph_runners[batch_size] = graph_runner
+                # Update the progress bar
+                progress.update(task, advance=1)
         end_time = time.perf_counter()
         elapsed_time = end_time - start_time
         # This usually takes < 10 seconds.
@@ -842,14 +915,14 @@ class CUDAGraphRunner:
         # NOTE: Python 3.8 does not support multi-line with statements.
         # https://stackoverflow.com/questions/31039022/python-multi-line-with-statement
         self.graph = torch.cuda.CUDAGraph()
-        with torch.cuda.graph(self.graph, pool=memory_pool):
-            with _maybe_cupy_nccl():
-                hidden_states = self.model(
-                    input_ids,
-                    positions,
-                    kv_caches,
-                    input_metadata,
-                )
+        with torch.cuda.graph(self.graph,
+                              pool=memory_pool), _maybe_cupy_nccl():
+            hidden_states = self.model(
+                input_ids,
+                positions,
+                kv_caches,
+                input_metadata,
+            )
         torch.cuda.synchronize()
 
         # Save the input and output buffers.

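Most of the churn in model_runner.py above is yapf re-wrapping calls to the same padding helpers rather than behaviour changes. As a rough sketch of what those helpers do (the real _make_tensor_with_pad / _pad_to_max live elsewhere in this file and may differ in detail; _PAD_SLOT_ID's value is assumed here), the pattern being reformatted is:

    from typing import List

    import torch

    _PAD_SLOT_ID = -1  # assumed sentinel; the real constant is defined in model_runner.py


    def _pad_to_max(x: List[int], max_len: int, pad: int) -> List[int]:
        # Right-pad one sequence with `pad` until it reaches max_len.
        return x + [pad] * (max_len - len(x))


    def _make_tensor_with_pad(x: List[List[int]], max_len: int, pad: int,
                              dtype: torch.dtype, device) -> torch.Tensor:
        # Pad every row to the same length, then stack the batch into a
        # single 2-D tensor on the target device.
        padded = [_pad_to_max(row, max_len, pad) for row in x]
        return torch.tensor(padded, dtype=dtype, device=device)

The three calls at the top of the hunk differ only in the pad value: token ids and positions pad with 0, while slot_mapping pads with _PAD_SLOT_ID, presumably so the attention kernels can skip the padded slots.
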
+ 2 - 1
aphrodite/task_handler/worker.py

@@ -86,7 +86,8 @@ class Worker:
             # This env var set by Ray causes exceptions with graph building.
             os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
 
-            # Patch for torch.cuda.is_available() unexpected error in WSL; always call torch.cuda.device_count() before initialising device
+            # Patch for torch.cuda.is_available() unexpected error in WSL;
+            # always call torch.cuda.device_count() before initialising device
             if in_wsl():
                 torch.cuda.device_count()
             self.device = torch.device(f"cuda:{self.local_rank}")

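The worker.py change only re-wraps the comment; the WSL workaround itself is untouched. For context, the pattern it describes is simply this (a minimal sketch, assuming in_wsl comes from aphrodite.common.utils as used elsewhere in this PR):

    import torch

    from aphrodite.common.utils import in_wsl  # import path assumed


    def init_device(local_rank: int) -> torch.device:
        if in_wsl():
            # Querying the device count first works around a spurious
            # torch.cuda.is_available() error under WSL.
            torch.cuda.device_count()
        return torch.device(f"cuda:{local_rank}")
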
+ 108 - 151
aphrodite/transformers_utils/configs/mpt.py

@@ -2,123 +2,70 @@
 # Copied from
 # https://huggingface.co/mosaicml/mpt-7b/blob/main/configuration_mpt.py
 """A HuggingFace-style model configuration."""
+
 import warnings
 from typing import Any, Dict, Optional, Union
 
 from transformers import PretrainedConfig
 
 attn_config_defaults: Dict = {
-    'attn_type': 'multihead_attention',
-    'attn_pdrop': 0.0,
-    'attn_impl': 'triton',
-    'qk_ln': False,
-    'clip_qkv': None,
-    'softmax_scale': None,
-    'prefix_lm': False,
-    'attn_uses_sequence_id': False,
-    'alibi': False,
-    'alibi_bias_max': 8
+    "attn_type": "multihead_attention",
+    "attn_pdrop": 0.0,
+    "attn_impl": "triton",
+    "qk_ln": False,
+    "clip_qkv": None,
+    "softmax_scale": None,
+    "prefix_lm": False,
+    "attn_uses_sequence_id": False,
+    "alibi": False,
+    "alibi_bias_max": 8,
 }
-ffn_config_defaults: Dict = {'ffn_type': 'mptmlp'}
+ffn_config_defaults: Dict = {"ffn_type": "mptmlp"}
 init_config_defaults: Dict = {
-    'name': 'kaiming_normal_',
-    'fan_mode': 'fan_in',
-    'init_nonlinearity': 'relu',
-    'init_div_is_residual': True,
-    'emb_init_std': None,
-    'emb_init_uniform_lim': None,
-    'init_std': None,
-    'init_gain': 0.0
+    "name": "kaiming_normal_",
+    "fan_mode": "fan_in",
+    "init_nonlinearity": "relu",
+    "init_div_is_residual": True,
+    "emb_init_std": None,
+    "emb_init_uniform_lim": None,
+    "init_std": None,
+    "init_gain": 0.0,
 }
 
 
 class MPTConfig(PretrainedConfig):
-    model_type = 'mpt'
+    model_type = "mpt"
     attribute_map = {
-        'num_attention_heads': 'n_heads',
-        'hidden_size': 'd_model',
-        'num_hidden_layers': 'n_layers',
+        "num_attention_heads": "n_heads",
+        "hidden_size": "d_model",
+        "num_hidden_layers": "n_layers",
     }
 
     # pylint: disable=dangerous-default-value
-    def __init__(self,
-                 d_model: int = 2048,
-                 n_heads: int = 16,
-                 n_layers: int = 24,
-                 expansion_ratio: int = 4,
-                 max_seq_len: int = 2048,
-                 vocab_size: int = 50368,
-                 resid_pdrop: float = 0.0,
-                 emb_pdrop: float = 0.0,
-                 learned_pos_emb: bool = True,
-                 attn_config: Dict = attn_config_defaults,
-                 ffn_config: Dict = ffn_config_defaults,
-                 init_device: str = 'cpu',
-                 logit_scale: Optional[Union[float, str]] = None,
-                 no_bias: bool = False,
-                 embedding_fraction: float = 1.0,
-                 norm_type: str = 'low_precision_layernorm',
-                 use_cache: bool = False,
-                 init_config: Dict = init_config_defaults,
-                 fc_type: str = 'torch',
-                 verbose: Optional[int] = None,
-                 **kwargs: Any):
-        """The MPT configuration class.
-        Args:
-            d_model (int): The size of the embedding dimension of the model.
-            n_heads (int): The number of attention heads.
-            n_layers (int): The number of layers in the model.
-            expansion_ratio (int): The ratio of the up/down scale in the ffn.
-            max_seq_len (int): The maximum sequence length of the model.
-            vocab_size (int): The size of the vocabulary.
-            resid_pdrop (float): The dropout probability applied to the attention output before combining with residual.
-            emb_pdrop (float): The dropout probability for the embedding layer.
-            learned_pos_emb (bool): Whether to use learned positional embeddings
-            attn_config (Dict): A dictionary used to configure the model's attention module:
-                attn_type (str): type of attention to use. Options: multihead_attention, multiquery_attention, grouped_query_attention
-                attn_pdrop (float): The dropout probability for the attention layers.
-                attn_impl (str): The attention implementation to use. One of 'torch', 'flash', or 'triton'.
-                qk_ln (bool): Whether to apply layer normalization to the queries and keys in the attention layer.
-                clip_qkv (Optional[float]): If not None, clip the queries, keys, and values in the attention layer to
-                    this value.
-                softmax_scale (Optional[float]): If not None, scale the softmax in the attention layer by this value. If None,
-                    use the default scale of ``1/sqrt(d_keys)``.
-                prefix_lm (Optional[bool]): Whether the model should operate as a Prefix LM. This requires passing an
-                    extra `prefix_mask` argument which indicates which tokens belong to the prefix. Tokens in the prefix
-                    can attend to one another bi-directionally. Tokens outside the prefix use causal attention.
-                attn_uses_sequence_id (Optional[bool]): Whether to restrict attention to tokens that have the same sequence_id.
-                    When the model is in `train` mode, this requires passing an extra `sequence_id` argument which indicates
-                    which sub-sequence each token belongs to.
-                    Defaults to ``False`` meaning any provided `sequence_id` will be ignored.
-                alibi (bool): Whether to use the alibi bias instead of position embeddings.
-                alibi_bias_max (int): The maximum value of the alibi bias.
-                kv_n_heads (Optional[int]): For grouped_query_attention only, allow user to specify number of kv heads.
-            ffn_config (Dict): A dictionary used to configure the model's ffn module:
-                ffn_type (str): type of ffn to use. Options: mptmlp, te_ln_mlp
-            init_device (str): The device to use for parameter initialization.
-            logit_scale (Optional[Union[float, str]]): If not None, scale the logits by this value.
-            no_bias (bool): Whether to use bias in all layers.
-            verbose (int): The verbosity level. 0 is silent.
-            embedding_fraction (float): The fraction to scale the gradients of the embedding layer by.
-            norm_type (str): choose type of norm to use
-            use_cache (bool): Whether or not the model should return the last key/values attentions
-            init_config (Dict): A dictionary used to configure the model initialization:
-                init_config.name: The parameter initialization scheme to use. Options: 'default_', 'baseline_',
-                    'kaiming_uniform_', 'kaiming_normal_', 'neox_init_', 'small_init_', 'xavier_uniform_', or
-                    'xavier_normal_'. These mimic the parameter initialization methods in PyTorch.
-                init_div_is_residual (Union[int, float, str, bool]): Value to divide initial weights by if ``module._is_residual`` is True.
-                emb_init_std (Optional[float]): The standard deviation of the normal distribution used to initialize the embedding layer.
-                emb_init_uniform_lim (Optional[Union[Tuple[float, float], float]]): The lower and upper limits of the uniform distribution
-                    used to initialize the embedding layer. Mutually exclusive with ``emb_init_std``.
-                init_std (float): The standard deviation of the normal distribution used to initialize the model,
-                    if using the baseline_ parameter initialization scheme.
-                init_gain (float): The gain to use for parameter initialization with kaiming or xavier initialization schemes.
-                fan_mode (str): The fan mode to use for parameter initialization with kaiming initialization schemes.
-                init_nonlinearity (str): The nonlinearity to use for parameter initialization with kaiming initialization schemes.
-                ---
-                See llmfoundry.models.utils.param_init_fns.py for info on other param init config options
-            fc_type (str): choose fc layer implementation. Options: torch and te. te layers support fp8 when using H100 GPUs.
-        """
+    def __init__(
+        self,
+        d_model: int = 2048,
+        n_heads: int = 16,
+        n_layers: int = 24,
+        expansion_ratio: int = 4,
+        max_seq_len: int = 2048,
+        vocab_size: int = 50368,
+        resid_pdrop: float = 0.0,
+        emb_pdrop: float = 0.0,
+        learned_pos_emb: bool = True,
+        attn_config: Dict = attn_config_defaults,
+        ffn_config: Dict = ffn_config_defaults,
+        init_device: str = "cpu",
+        logit_scale: Optional[Union[float, str]] = None,
+        no_bias: bool = False,
+        embedding_fraction: float = 1.0,
+        norm_type: str = "low_precision_layernorm",
+        use_cache: bool = False,
+        init_config: Dict = init_config_defaults,
+        fc_type: str = "torch",
+        verbose: Optional[int] = None,
+        **kwargs: Any,
+    ):
         self.d_model = d_model
         self.n_heads = n_heads
         self.n_layers = n_layers
@@ -139,26 +86,30 @@ class MPTConfig(PretrainedConfig):
         self.init_config = init_config
         self.fc_type = fc_type
         if verbose is not None:
-            warnings.warn(DeprecationWarning(
-                'verbose argument for MPTConfig is now ignored and will be removed. Use python_log_level instead.'
-            ),
-                          stacklevel=2)
-        if 'name' in kwargs:
-            del kwargs['name']
-        if 'loss_fn' in kwargs:
-            del kwargs['loss_fn']
-        if self.attn_config.get('alibi', False):
+            warnings.warn(
+                DeprecationWarning(
+                    "verbose argument for MPTConfig is now ignored and will be"
+                    " removed. Use python_log_level instead."),
+                stacklevel=2,
+            )
+        if "name" in kwargs:
+            del kwargs["name"]
+        if "loss_fn" in kwargs:
+            del kwargs["loss_fn"]
+        if self.attn_config.get("alibi", False):
             self.learned_pos_emb = False
             warnings.warn(
-                f'alibi is turned on, setting `learned_pos_emb` to {self.learned_pos_emb}`',
-                stacklevel=2)
+                "alibi is turned on, setting `learned_pos_emb` to "
+                f"{self.learned_pos_emb}`",
+                stacklevel=2,
+            )
         super().__init__(**kwargs)
         self._validate_config()
 
     def _set_config_defaults(
             self, config: Dict[str, Any],
             config_defaults: Dict[str, Any]) -> Dict[str, Any]:
-        for (k, v) in config_defaults.items():
+        for k, v in config_defaults.items():
             if k not in config:
                 config[k] = v
         return config
@@ -171,63 +122,69 @@ class MPTConfig(PretrainedConfig):
         self.init_config = self._set_config_defaults(self.init_config,
                                                      init_config_defaults)
         if self.d_model % self.n_heads != 0:
-            raise ValueError('d_model must be divisible by n_heads')
-        if any((
-                prob < 0 or prob > 1 for prob in
-            [self.attn_config['attn_pdrop'], self.resid_pdrop, self.emb_pdrop]
-        )):
+            raise ValueError("d_model must be divisible by n_heads")
+        if any((prob < 0 or prob > 1 for prob in [
+                self.attn_config["attn_pdrop"],
+                self.resid_pdrop,
+                self.emb_pdrop,
+        ])):
             raise ValueError(
-                "self.attn_config['attn_pdrop'], resid_pdrop, emb_pdrop are probabilities and must be between 0 and 1"  # pylint: disable=line-too-long
-            )
-        if self.attn_config['attn_impl'] not in ['torch', 'flash', 'triton']:
+                "self.attn_config['attn_pdrop'], resid_pdrop, emb_pdrop "
+                "are probabilities and must be between 0 and 1")
+        if self.attn_config["attn_impl"] not in ["torch", "flash", "triton"]:
             raise ValueError(
                 f"Unknown attn_impl={self.attn_config['attn_impl']}")
-        if self.attn_config['prefix_lm'] and self.attn_config[
-                'attn_impl'] not in ['torch', 'triton']:
+        if self.attn_config["prefix_lm"] and self.attn_config[
+                "attn_impl"] not in ["torch", "triton"]:
             raise NotImplementedError(
-                'prefix_lm only implemented with torch and triton attention.')
-        if self.attn_config['alibi'] and self.attn_config['attn_impl'] not in [
-                'torch', 'triton'
+                "prefix_lm only implemented with torch and triton attention.")
+        if self.attn_config["alibi"] and self.attn_config["attn_impl"] not in [
+                "torch",
+                "triton",
         ]:
             raise NotImplementedError(
-                'alibi only implemented with torch and triton attention.')
-        if self.attn_config['attn_uses_sequence_id'] and self.attn_config[
-                'attn_impl'] not in ['torch', 'triton']:
+                "alibi only implemented with torch and triton attention.")
+        if self.attn_config["attn_uses_sequence_id"] and self.attn_config[
+                "attn_impl"] not in ["torch", "triton"]:
             raise NotImplementedError(
-                'attn_uses_sequence_id only implemented with torch and triton attention.'  # pylint: disable=line-too-long
-            )
+                "attn_uses_sequence_id only implemented with torch and triton "
+                "attention.")
         if self.embedding_fraction > 1 or self.embedding_fraction <= 0:
             raise ValueError(
-                'model.embedding_fraction must be between 0 (exclusive) and 1 (inclusive)!'  # pylint: disable=line-too-long
-            )
-        if isinstance(self.logit_scale,
-                      str) and self.logit_scale != 'inv_sqrt_d_model':
+                "model.embedding_fraction must be between 0 (exclusive) and 1 "
+                "(inclusive)!")
+        if (isinstance(self.logit_scale, str)
+                and self.logit_scale != "inv_sqrt_d_model"):
             raise ValueError(
-                f"self.logit_scale={self.logit_scale!r} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'."  # pylint: disable=line-too-long
-            )
-        if self.init_config.get('name', None) is None:
+                f"self.logit_scale={self.logit_scale!r} is not recognized as "
+                "an option; use numeric value or 'inv_sqrt_d_model'.")
+        if self.init_config.get("name", None) is None:
             raise ValueError(
                 f"self.init_config={self.init_config!r} 'name' needs to be set."
             )
-        if not self.learned_pos_emb and (not self.attn_config['alibi']):
+        if not self.learned_pos_emb and (not self.attn_config["alibi"]):
             warnings.warn(
-                'Positional information not being provided to the model.',
-                stacklevel=2)
-        if self.fc_type == 'te' or self.ffn_config['ffn_type'] == 'te_ln_mlp':
+                "Positional information not being provided to the model.",
+                stacklevel=2,
+            )
+        if self.fc_type == "te" or self.ffn_config[  # codespell:ignore
+                "ffn_type"] == "te_ln_mlp":
             try:
                 # pylint: disable=import-outside-toplevel
                 import transformer_engine.pytorch as te
+
                 del te
             except Exception as exc:
                 raise ImportError(
                     # pylint: disable=line-too-long
-                    'TransformerEngine import fail. `fc_type: te` requires TransformerEngine be installed. '
-                    +
-                    'The required version of transformer_engine also requires FlashAttention v1.0.6 is installed:\n'
-                    + 'pip install flash-attn==1.0.6 --no-build-isolation \n' +
-                    'pip install git+https://github.com/NVIDIA/TransformerEngine.git@144e4888b2cdd60bd52e706d5b7a79cb9c1a7156'
+                    "TransformerEngine import fail. `fc_type: te` requires "
+                    "TransformerEngine be installed. " +
+                    "The required version of transformer_engine also "
+                    "requires FlashAttention v1.0.6 is installed:\n" +
+                    "pip install flash-attn==1.0.6 --no-build-isolation \n" +
+                    "pip install git+https://github.com/NVIDIA/TransformerEngine.git@144e4888b2cdd60bd52e706d5b7a79cb9c1a7156"
                 ) from exc
-        if self.ffn_config['ffn_type'] == 'mptmlp':
-            self.ffn_config['fc_type'] = self.fc_type
-        elif self.ffn_config['ffn_type'] == 'te_ln_mlp':
-            self.ffn_config['bias'] = not self.no_bias
+        if self.ffn_config["ffn_type"] == "mptmlp":
+            self.ffn_config["fc_type"] = self.fc_type
+        elif self.ffn_config["ffn_type"] == "te_ln_mlp":
+            self.ffn_config["bias"] = not self.no_bias

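Apart from quote style and line wrapping, MPTConfig behaves as before: any keys missing from attn_config, ffn_config or init_config are back-filled from the module-level defaults inside _validate_config(). A quick illustration (hypothetical values, sketch only):

    from aphrodite.transformers_utils.configs.mpt import MPTConfig

    # Override only the keys you care about; _validate_config() merges the
    # rest back in from attn_config_defaults / ffn_config_defaults.
    cfg = MPTConfig(d_model=1024, n_heads=16,
                    attn_config={"attn_impl": "torch"})
    assert cfg.attn_config["attn_impl"] == "torch"   # user override kept
    assert cfg.attn_config["alibi"] is False         # default filled in
    assert cfg.ffn_config["fc_type"] == cfg.fc_type  # set for "mptmlp"
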
+ 2 - 1
aphrodite/transformers_utils/configs/olmo.py

@@ -12,7 +12,8 @@ class OLMoConfig(PretrainedConfig):
         'num_hidden_layers': 'n_layers',
     }
 
-    # Note that the defaults for these attributes are equivalent to the base GPT2 model.
+    # Note that the defaults for these attributes are equivalent to the
+    # base GPT2 model.
     def __init__(
         self,
         d_model=768,

+ 22 - 11
aphrodite/transformers_utils/tokenizers/baichuan.py

@@ -50,7 +50,10 @@ class BaichuanTokenizer(PreTrainedTokenizer):
         clean_up_tokenization_spaces=False,
         **kwargs,
     ):
-        self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
+        if sp_model_kwargs is None:
+            self.sp_model_kwargs = {}
+        else:
+            self.sp_model_kwargs = sp_model_kwargs
         bos_token = (
             AddedToken(bos_token, lstrip=False, rstrip=False)
             if isinstance(bos_token, str)
@@ -105,7 +108,8 @@ class BaichuanTokenizer(PreTrainedTokenizer):
 
     def get_vocab(self):
         """Returns vocab as a dict"""
-        vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
+        vocab = {self.convert_ids_to_tokens(i): i for i in range(
+            self.vocab_size)}
         vocab.update(self.added_tokens_encoder)
         return vocab
 
@@ -128,7 +132,8 @@ class BaichuanTokenizer(PreTrainedTokenizer):
         out_string = ""
         prev_is_special = False
         for i, token in enumerate(tokens):
-            # make sure that special tokens are not decoded using sentencepiece model
+            # make sure that special tokens are not decoded using
+            # sentencepiece model
             if token in self.all_special_tokens:
                 if not prev_is_special and i != 0:
                     out_string += " "
@@ -155,7 +160,8 @@ class BaichuanTokenizer(PreTrainedTokenizer):
             `Tuple(str)`: Paths to the files saved.
         """
         if not os.path.isdir(save_directory):
-            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
+            logger.error(f"Vocabulary path ({save_directory}) should be"
+                         " a directory")
             return
         out_vocab_file = os.path.join(
             save_directory,
@@ -192,19 +198,22 @@ class BaichuanTokenizer(PreTrainedTokenizer):
         already_has_special_tokens: bool = False,
     ) -> List[int]:
         """
-        Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
-        special tokens using the tokenizer `prepare_for_model` method.
+        Retrieve sequence ids from a token list that has no special tokens
+        added. This method is called when adding special tokens using the
+        tokenizer `prepare_for_model` method.
 
         Args:
             token_ids_0 (`List[int]`):
                 List of IDs.
             token_ids_1 (`List[int]`, *optional*):
                 Optional second list of IDs for sequence pairs.
-            already_has_special_tokens (`bool`, *optional*, defaults to `False`):
-                Whether or not the token list is already formatted with special tokens for the model.
+            already_has_special_tokens(`bool`, *optional*, defaults to `False`):
+                Whether or not the token list is already formatted with special
+                tokens for the model.
 
         Returns:
-            `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
+            `List[int]`: A list of integers in the range [0, 1]: 1 for a
+            special token, 0 for a sequence token.
         """
         if already_has_special_tokens:
             return super().get_special_tokens_mask(
@@ -231,7 +240,8 @@ class BaichuanTokenizer(PreTrainedTokenizer):
         self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
     ) -> List[int]:
         """
-        Creates a mask from the two sequences passed to be used in a sequence-pair classification task. An ALBERT
+        Creates a mask from the two sequences passed to be used in a
+        sequence-pair classification task. An ALBERT
         sequence pair mask has the following format:
 
         ```
@@ -248,7 +258,8 @@ class BaichuanTokenizer(PreTrainedTokenizer):
                 Optional second list of IDs for sequence pairs.
 
         Returns:
-            `List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s).
+            `List[int]`: List of [token type IDs](../glossary#token-type-ids)
+            according to the given sequence(s).
         """
         bos_token_id = [self.bos_token_id] if self.add_bos_token else []
         eos_token_id = [self.eos_token_id] if self.add_eos_token else []

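The BaichuanTokenizer edits are wrapping only; the special-token logic the docstrings describe is unchanged. As a rough illustration of the two return shapes (a sketch with made-up ids, assuming add_bos_token and add_eos_token are both enabled):

    # Hypothetical token ids, for illustration only.
    seq_a = [101, 102, 103]
    seq_b = [201, 202]

    # get_special_tokens_mask: 1 marks a special token, 0 a sequence token.
    mask = [1] + [0] * len(seq_a) + [1] + [1] + [0] * len(seq_b) + [1]

    # create_token_type_ids_from_sequences: ALBERT-style pair mask, 0s for
    # the first sequence (plus its specials), 1s for the second.
    token_type_ids = [0] * (len(seq_a) + 2) + [1] * (len(seq_b) + 2)
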
+ 145 - 102
env.py

@@ -6,7 +6,6 @@ import sys
 import os
 from collections import namedtuple
 
-
 try:
     import torch
     TORCH_AVAILABLE = True
@@ -14,36 +13,38 @@ except (ImportError, NameError, AttributeError, OSError):
     TORCH_AVAILABLE = False
 
 # System Environment Information
-SystemEnv = namedtuple('SystemEnv', [
-    'torch_version',
-    'is_debug_build',
-    'cuda_compiled_version',
-    'gcc_version',
-    'clang_version',
-    'cmake_version',
-    'os',
-    'libc_version',
-    'python_version',
-    'python_platform',
-    'is_cuda_available',
-    'cuda_runtime_version',
-    'cuda_module_loading',
-    'nvidia_driver_version',
-    'nvidia_gpu_models',
-    'cudnn_version',
-    'pip_version',  # 'pip' or 'pip3'
-    'pip_packages',
-    'conda_packages',
-    'hip_compiled_version',
-    'hip_runtime_version',
-    'miopen_runtime_version',
-    'caching_allocator_config',
-    'is_xnnpack_available',
-    'cpu_info',
-    'rocm_version',
-    'aphrodite_version',
-    'aphrodite_build_flags',
-])
+SystemEnv = namedtuple(
+    'SystemEnv',
+    [
+        'torch_version',
+        'is_debug_build',
+        'cuda_compiled_version',
+        'gcc_version',
+        'clang_version',
+        'cmake_version',
+        'os',
+        'libc_version',
+        'python_version',
+        'python_platform',
+        'is_cuda_available',
+        'cuda_runtime_version',
+        'cuda_module_loading',
+        'nvidia_driver_version',
+        'nvidia_gpu_models',
+        'cudnn_version',
+        'pip_version',  # 'pip' or 'pip3'
+        'pip_packages',
+        'conda_packages',
+        'hip_compiled_version',
+        'hip_runtime_version',
+        'miopen_runtime_version',
+        'caching_allocator_config',
+        'is_xnnpack_available',
+        'cpu_info',
+        'rocm_version',
+        'aphrodite_version',
+        'aphrodite_build_flags',
+    ])
 
 DEFAULT_CONDA_PATTERNS = {
     "torch",
@@ -69,22 +70,23 @@ DEFAULT_PIP_PATTERNS = {
 
 def run(command):
     """Return (return-code, stdout, stderr)."""
-    shell = True if type(command) is str else False
-    p = subprocess.Popen(command, stdout=subprocess.PIPE,
-                         stderr=subprocess.PIPE, shell=shell)
+    shell = isinstance(command, str)
+    p = subprocess.Popen(command,
+                         stdout=subprocess.PIPE,
+                         stderr=subprocess.PIPE,
+                         shell=shell)
     raw_output, raw_err = p.communicate()
     rc = p.returncode
-    if get_platform() == 'win32':
-        enc = 'oem'
-    else:
-        enc = locale.getpreferredencoding()
+    enc = 'oem' if get_platform() == 'win32' else locale.getpreferredencoding()
     output = raw_output.decode(enc)
     err = raw_err.decode(enc)
     return rc, output.strip(), err.strip()
 
 
 def run_and_read_all(run_lambda, command):
-    """Run command using run_lambda; reads and returns entire output if rc is 0."""
+    """
+    Run command using run_lambda; reads and returns entire output if rc is 0.
+    """
     rc, out, _ = run_lambda(command)
     if rc != 0:
         return None
@@ -92,7 +94,9 @@ def run_and_read_all(run_lambda, command):
 
 
 def run_and_parse_first_match(run_lambda, command, regex):
-    """Run command using run_lambda, returns the first regex match if it exists."""
+    """
+    Run command using run_lambda, returns the first regex match if it exists.
+    """
     rc, out, _ = run_lambda(command)
     if rc != 0:
         return None
@@ -101,8 +105,11 @@ def run_and_parse_first_match(run_lambda, command, regex):
         return None
     return match.group(1)
 
+
 def run_and_return_first_line(run_lambda, command):
-    """Run command using run_lambda and returns first line if output is not empty."""
+    """
+    Run command using run_lambda and returns first line if output is not empty.
+    """
     rc, out, _ = run_lambda(command)
     if rc != 0:
         return None
@@ -117,22 +124,23 @@ def get_conda_packages(run_lambda, patterns=None):
     if out is None:
         return out
 
-    return "\n".join(
-        line
-        for line in out.splitlines()
-        if not line.startswith("#")
-        and any(name in line for name in patterns)
-    )
+    return "\n".join(line for line in out.splitlines()
+                     if not line.startswith("#") and any(name in line
+                                                         for name in patterns))
+
 
 def get_gcc_version(run_lambda):
     return run_and_parse_first_match(run_lambda, 'gcc --version', r'gcc (.*)')
 
+
 def get_clang_version(run_lambda):
-    return run_and_parse_first_match(run_lambda, 'clang --version', r'clang version (.*)')
+    return run_and_parse_first_match(run_lambda, 'clang --version',
+                                     r'clang version (.*)')
 
 
 def get_cmake_version(run_lambda):
-    return run_and_parse_first_match(run_lambda, 'cmake --version', r'cmake (.*)')
+    return run_and_parse_first_match(run_lambda, 'cmake --version',
+                                     r'cmake (.*)')
 
 
 def get_nvidia_driver_version(run_lambda):
@@ -141,11 +149,13 @@ def get_nvidia_driver_version(run_lambda):
         return run_and_parse_first_match(run_lambda, cmd,
                                          r'com[.]nvidia[.]CUDA [(](.*?)[)]')
     smi = get_nvidia_smi()
-    return run_and_parse_first_match(run_lambda, smi, r'Driver Version: (.*?) ')
+    return run_and_parse_first_match(run_lambda, smi,
+                                     r'Driver Version: (.*?) ')
 
 
 def get_gpu_info(run_lambda):
-    if get_platform() == 'darwin' or (TORCH_AVAILABLE and hasattr(torch.version, 'hip') and torch.version.hip is not None):
+    if get_platform() == 'darwin' or (TORCH_AVAILABLE and hasattr(
+            torch.version, 'hip') and torch.version.hip is not None):
         if TORCH_AVAILABLE and torch.cuda.is_available():
             if torch.version.hip is not None:
                 prop = torch.cuda.get_device_properties(0)
@@ -167,11 +177,14 @@ def get_gpu_info(run_lambda):
 
 
 def get_running_cuda_version(run_lambda):
-    return run_and_parse_first_match(run_lambda, 'nvcc --version', r'release .+ V(.*)')
+    return run_and_parse_first_match(run_lambda, 'nvcc --version',
+                                     r'release .+ V(.*)')
 
 
 def get_cudnn_version(run_lambda):
-    """Return a list of libcudnn.so; it's hard to tell which one is being used."""
+    """
+    Return a list of libcudnn.so; it's hard to tell which one is being used.
+    """
     if get_platform() == 'win32':
         system_root = os.environ.get('SYSTEMROOT', 'C:\\Windows')
         cuda_path = os.environ.get('CUDA_PATH', "%CUDA_PATH%")
@@ -188,9 +201,9 @@ def get_cudnn_version(run_lambda):
     rc, out, _ = run_lambda(cudnn_cmd)
     # find will return 1 if there are permission errors or if not found
     if len(out) == 0 or (rc != 1 and rc != 0):
-        l = os.environ.get('CUDNN_LIBRARY')
-        if l is not None and os.path.isfile(l):
-            return os.path.realpath(l)
+        cudnn_lib = os.environ.get('CUDNN_LIBRARY')
+        if cudnn_lib is not None and os.path.isfile(cudnn_lib):
+            return os.path.realpath(cudnn_lib)
         return None
     files_set = set()
     for fn in out.split('\n'):
@@ -212,8 +225,10 @@ def get_nvidia_smi():
     smi = 'nvidia-smi'
     if get_platform() == 'win32':
         system_root = os.environ.get('SYSTEMROOT', 'C:\\Windows')
-        program_files_root = os.environ.get('PROGRAMFILES', 'C:\\Program Files')
-        legacy_path = os.path.join(program_files_root, 'NVIDIA Corporation', 'NVSMI', smi)
+        program_files_root = os.environ.get('PROGRAMFILES',
+                                            'C:\\Program Files')
+        legacy_path = os.path.join(program_files_root, 'NVIDIA Corporation',
+                                   'NVSMI', smi)
         new_path = os.path.join(system_root, 'System32', smi)
         smis = [new_path, legacy_path]
         for candidate_smi in smis:
@@ -225,8 +240,8 @@ def get_nvidia_smi():
 
 def get_rocm_version(run_lambda):
     """Returns the ROCm version if available, otherwise 'N/A'."""
-    return run_and_parse_first_match(run_lambda, 'hipcc --version', r'HIP version: (\S+)')
-
+    return run_and_parse_first_match(run_lambda, 'hipcc --version',
+                                     r'HIP version: (\S+)')
 
 
 def get_aphrodite_version():
@@ -238,7 +253,8 @@ def get_aphrodite_version():
 
 
 def summarize_aphrodite_build_flags():
-    # This could be a static method if the flags are constant, or dynamic if you need to check environment variables, etc.
+    # This could be a static method if the flags are constant, or dynamic if
+    # you need to check environment variables, etc.
     return 'CUDA Archs: {}; ROCm: {}'.format(
         os.environ.get('TORCH_CUDA_ARCH_LIST', 'Not Set'),
         'Enabled' if os.environ.get('ROCM_HOME') else 'Disabled',
@@ -250,15 +266,13 @@ def get_cpu_info(run_lambda):
     if get_platform() == 'linux':
         rc, out, err = run_lambda('lscpu')
     elif get_platform() == 'win32':
-        rc, out, err = run_lambda('wmic cpu get Name,Manufacturer,Family,Architecture,ProcessorType,DeviceID, \
-        CurrentClockSpeed,MaxClockSpeed,L2CacheSize,L2CacheSpeed,Revision /VALUE')
+        rc, out, err = run_lambda(
+            'wmic cpu get Name,Manufacturer,Family,Architecture,ProcessorType, \
+                DeviceID,CurrentClockSpeed,MaxClockSpeed,L2CacheSize, \
+                    L2CacheSpeed,Revision /VALUE')
     elif get_platform() == 'darwin':
         rc, out, err = run_lambda("sysctl -n machdep.cpu.brand_string")
-    cpu_info = 'None'
-    if rc == 0:
-        cpu_info = out
-    else:
-        cpu_info = err
+    cpu_info = out if rc == 0 else err
     return cpu_info
 
 
@@ -276,18 +290,22 @@ def get_platform():
 
 
 def get_mac_version(run_lambda):
-    return run_and_parse_first_match(run_lambda, 'sw_vers -productVersion', r'(.*)')
+    return run_and_parse_first_match(run_lambda, 'sw_vers -productVersion',
+                                     r'(.*)')
 
 
 def get_windows_version(run_lambda):
     system_root = os.environ.get('SYSTEMROOT', 'C:\\Windows')
     wmic_cmd = os.path.join(system_root, 'System32', 'Wbem', 'wmic')
     findstr_cmd = os.path.join(system_root, 'System32', 'findstr')
-    return run_and_read_all(run_lambda, '{} os get Caption | {} /v Caption'.format(wmic_cmd, findstr_cmd))
+    return run_and_read_all(
+        run_lambda,
+        '{} os get Caption | {} /v Caption'.format(wmic_cmd, findstr_cmd))
 
 
 def get_lsb_version(run_lambda):
-    return run_and_parse_first_match(run_lambda, 'lsb_release -a', r'Description:\t(.*)')
+    return run_and_parse_first_match(run_lambda, 'lsb_release -a',
+                                     r'Description:\t(.*)')
 
 
 def check_release_file(run_lambda):
@@ -338,7 +356,10 @@ def get_libc_version():
 
 
 def get_pip_packages(run_lambda, patterns=None):
-    """Return `pip list` output. Note: will also find conda-installed pytorch and numpy packages."""
+    """
+    Return `pip list` output. Note: will also find conda-installed pytorch and
+    numpy packages.
+    """
     if patterns is None:
         patterns = DEFAULT_PIP_PATTERNS
 
@@ -346,11 +367,8 @@ def get_pip_packages(run_lambda, patterns=None):
     # But here it is invoked as `python -mpip`
     def run_with_pip(pip):
         out = run_and_read_all(run_lambda, pip + ["list", "--format=freeze"])
-        return "\n".join(
-            line
-            for line in out.splitlines()
-            if any(name in line for name in patterns)
-        )
+        return "\n".join(line for line in out.splitlines()
+                         if any(name in line for name in patterns))
 
     pip_version = 'pip3' if sys.version[0] == '3' else 'pip'
     out = run_with_pip([sys.executable, '-mpip'])
@@ -375,10 +393,12 @@ def get_cuda_module_loading_config():
 def is_xnnpack_available():
     if TORCH_AVAILABLE:
         import torch.backends.xnnpack
-        return str(torch.backends.xnnpack.enabled)  # type: ignore[attr-defined]
+        return str(
+            torch.backends.xnnpack.enabled)  # type: ignore[attr-defined]
     else:
         return "N/A"
 
+
 def get_env_info():
     run_lambda = run
     pip_version, pip_list_output = get_pip_packages(run_lambda)
@@ -388,9 +408,13 @@ def get_env_info():
         debug_mode_str = str(torch.version.debug)
         cuda_available_str = str(torch.cuda.is_available())
         cuda_version_str = torch.version.cuda
-        if not hasattr(torch.version, 'hip') or torch.version.hip is None:  # cuda version
-            hip_compiled_version = hip_runtime_version = miopen_runtime_version = 'N/A'
+        if not hasattr(torch.version,
+                       'hip') or torch.version.hip is None:  # cuda version
+            hip_compiled_version = 'N/A'
+            hip_runtime_version = 'N/A'
+            miopen_runtime_version = 'N/A'
         else:  # HIP version
+
             def get_version_or_na(cfg, prefix):
                 _lst = [s.rsplit(None, 1)[-1] for s in cfg if prefix in s]
                 return _lst[0] if _lst else 'N/A'
@@ -401,8 +425,13 @@ def get_env_info():
             cuda_version_str = 'N/A'
             hip_compiled_version = torch.version.hip
     else:
-        version_str = debug_mode_str = cuda_available_str = cuda_version_str = 'N/A'
-        hip_compiled_version = hip_runtime_version = miopen_runtime_version = 'N/A'
+        version_str = 'N/A'
+        debug_mode_str = 'N/A'
+        cuda_available_str = 'N/A'
+        cuda_version_str = 'N/A'
+        hip_compiled_version = 'N/A'
+        hip_runtime_version = 'N/A'
+        miopen_runtime_version = 'N/A'
 
     sys_version = sys.version.replace("\n", " ")
 
@@ -415,7 +444,9 @@ def get_env_info():
     return SystemEnv(
         torch_version=version_str,
         is_debug_build=debug_mode_str,
-        python_version='{} ({}-bit runtime)'.format(sys_version, sys.maxsize.bit_length() + 1),
+        python_version='{} ({}-bit runtime)'.format(
+            sys_version,
+            sys.maxsize.bit_length() + 1),
         python_platform=get_python_platform(),
         is_cuda_available=cuda_available_str,
         cuda_compiled_version=cuda_version_str,
@@ -443,6 +474,7 @@ def get_env_info():
         aphrodite_build_flags=aphrodite_build_flags,
     )
 
+
 env_info_fmt = """
 PyTorch version: {torch_version}
 Is debug build: {is_debug_build}
@@ -480,15 +512,16 @@ Aphrodite Build Flags:
 
 
 def pretty_str(envinfo):
+
     def replace_nones(dct, replacement='Could not collect '):
-        for key in dct.keys():
+        for key in dct:
             if dct[key] is not None:
                 continue
             dct[key] = replacement
         return dct
 
     def replace_bools(dct, true='Yes', false='No'):
-        for key in dct.keys():
+        for key in dct:
             if dct[key] is True:
                 dct[key] = true
             elif dct[key] is False:
@@ -524,9 +557,10 @@ def pretty_str(envinfo):
         'nvidia_driver_version',
     ]
     all_cuda_fields = dynamic_cuda_fields + ['cudnn_version']
-    all_dynamic_cuda_fields_missing = all(
-        mutable_dict[field] is None for field in dynamic_cuda_fields)
-    if TORCH_AVAILABLE and not torch.cuda.is_available() and all_dynamic_cuda_fields_missing:
+    all_dynamic_cuda_fields_missing = all(mutable_dict[field] is None
+                                          for field in dynamic_cuda_fields)
+    if TORCH_AVAILABLE and not torch.cuda.is_available(
+    ) and all_dynamic_cuda_fields_missing:
         for field in all_cuda_fields:
             mutable_dict[field] = 'No CUDA'
         if envinfo.cuda_compiled_version is None:
@@ -539,17 +573,20 @@ def pretty_str(envinfo):
     mutable_dict = replace_nones(mutable_dict)
 
     # If either of these are '', replace with 'No relevant packages'
-    mutable_dict['pip_packages'] = replace_if_empty(mutable_dict['pip_packages'])
-    mutable_dict['conda_packages'] = replace_if_empty(mutable_dict['conda_packages'])
+    mutable_dict['pip_packages'] = replace_if_empty(
+        mutable_dict['pip_packages'])
+    mutable_dict['conda_packages'] = replace_if_empty(
+        mutable_dict['conda_packages'])
 
     # Tag conda and pip packages with a prefix
-    # If they were previously None, they'll show up as ie '[conda] Could not collect'
+    # If they were previously None, they'll show up as ie '[conda] Could not
+    # collect'
     if mutable_dict['pip_packages']:
-        mutable_dict['pip_packages'] = prepend(mutable_dict['pip_packages'],
-                                               '[{}] '.format(envinfo.pip_version))
+        mutable_dict['pip_packages'] = prepend(
+            mutable_dict['pip_packages'], '[{}] '.format(envinfo.pip_version))
     if mutable_dict['conda_packages']:
-        mutable_dict['conda_packages'] = prepend(mutable_dict['conda_packages'],
-                                                 '[conda] ')
+        mutable_dict['conda_packages'] = prepend(
+            mutable_dict['conda_packages'], '[conda] ')
     mutable_dict['cpu_info'] = envinfo.cpu_info
     return env_info_fmt.format(**mutable_dict)
 
@@ -563,18 +600,24 @@ def main():
     output = get_pretty_env_info()
     print(output)
 
-    if TORCH_AVAILABLE and hasattr(torch, 'utils') and hasattr(torch.utils, '_crash_handler'):
+    if TORCH_AVAILABLE and hasattr(torch, 'utils') and hasattr(
+            torch.utils, '_crash_handler'):
         minidump_dir = torch.utils._crash_handler.DEFAULT_MINIDUMP_DIR
         if sys.platform == "linux" and os.path.exists(minidump_dir):
-            dumps = [os.path.join(minidump_dir, dump) for dump in os.listdir(minidump_dir)]
+            dumps = [
+                os.path.join(minidump_dir, dump)
+                for dump in os.listdir(minidump_dir)
+            ]
             latest = max(dumps, key=os.path.getctime)
             ctime = os.path.getctime(latest)
-            creation_time = datetime.datetime.fromtimestamp(ctime).strftime('%Y-%m-%d %H:%M:%S')
-            msg = "\n*** Detected a minidump at {} created on {}, ".format(latest, creation_time) + \
-                  "if this is related to your bug please include it when you file a report ***"
+            creation_time = datetime.datetime.fromtimestamp(ctime).strftime(
+                '%Y-%m-%d %H:%M:%S')
+            msg = "\n*** Detected a minidump at {} created on {}, ".format( \
+                latest, creation_time) + \
+                  "if this is related to your bug please include it when you " \
+                    "file a report ***"
             print(msg, file=sys.stderr)
 
 
-
 if __name__ == '__main__':
-    main()
+    main()

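env.py sees the largest mechanical reformat: most hunks are yapf wrapping the run_and_parse_first_match probes onto two lines. The helper chain itself is unchanged; stripped of the locale handling it amounts to the following (simplified sketch):

    import re
    import subprocess


    def run(command):
        """Return (return-code, stdout, stderr) for a shell or argv command."""
        shell = isinstance(command, str)
        p = subprocess.Popen(command, stdout=subprocess.PIPE,
                             stderr=subprocess.PIPE, shell=shell)
        out, err = p.communicate()
        return p.returncode, out.decode().strip(), err.decode().strip()


    def run_and_parse_first_match(run_lambda, command, regex):
        """Return the first capture group from the command output, or None."""
        rc, out, _ = run_lambda(command)
        if rc != 0:
            return None
        match = re.search(regex, out)
        return match.group(1) if match else None


    # e.g. the probe behind get_gcc_version() above:
    print(run_and_parse_first_match(run, 'gcc --version', r'gcc (.*)'))
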
+ 89 - 39
examples/gguf_to_torch.py

@@ -6,15 +6,19 @@ import gguf
 from sentencepiece import sentencepiece_model_pb2
 from safetensors.torch import save_file as safe_save_file
 from transformers.modeling_utils import shard_checkpoint
-from transformers.utils import WEIGHTS_NAME, WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, SAFE_WEIGHTS_INDEX_NAME
+from transformers.utils import (WEIGHTS_NAME, WEIGHTS_INDEX_NAME,
+                                SAFE_WEIGHTS_NAME, SAFE_WEIGHTS_INDEX_NAME)
 
-def convert_to_state_dict(checkpoint, save_dir, max_shard_size, safe_serialization):
+
+def convert_to_state_dict(checkpoint, save_dir, max_shard_size,
+                          safe_serialization):
     if not os.path.exists(save_dir):
         os.makedirs(save_dir)
     state_dict = {}
     result = gguf.GGUFReader(checkpoint)
     architecture = result.fields['general.architecture']
-    architecture = str(bytes(architecture.parts[architecture.data[0]]), encoding = 'utf-8')
+    architecture = str(bytes(architecture.parts[architecture.data[0]]),
+                       encoding='utf-8')
     if architecture != "llama":
         print(f"Unsupported architecture {architecture}")
         return
@@ -22,7 +26,7 @@ def convert_to_state_dict(checkpoint, save_dir, max_shard_size, safe_serializati
     # write vocab
     vocab = sentencepiece_model_pb2.ModelProto()
     vocab_size = len(result.fields['tokenizer.ggml.token_type'].data)
-    vocab.trainer_spec.model_type = 2 # BPE
+    vocab.trainer_spec.model_type = 2  # BPE
     vocab.trainer_spec.vocab_size = vocab_size
     vocab.trainer_spec.byte_fallback = True
     vocab.normalizer_spec.remove_extra_whitespaces = False
@@ -31,7 +35,8 @@ def convert_to_state_dict(checkpoint, save_dir, max_shard_size, safe_serializati
     types = result.fields['tokenizer.ggml.token_type']
     for i in range(vocab_size):
         new_token = vocab.SentencePiece()
-        new_token.piece = str(bytes(tokens.parts[tokens.data[i]]), encoding = 'utf-8')
+        new_token.piece = str(bytes(tokens.parts[tokens.data[i]]),
+                              encoding='utf-8')
         new_token.score = scores.parts[scores.data[i]]
         # llama.cpp tokentype is the same with sentencepiece token type
         new_token.type = int(types.parts[types.data[i]])
@@ -44,28 +49,40 @@ def convert_to_state_dict(checkpoint, save_dir, max_shard_size, safe_serializati
         "clean_up_tokenization_spaces": False,
     }
     if 'tokenizer.ggml.bos_token_id' in result.fields:
-        tokenizer_config["bos_token"] = vocab.pieces[int(result.fields['tokenizer.ggml.bos_token_id'].parts[-1])].piece
+        tokenizer_config["bos_token"] = vocab.pieces[int(
+            result.fields['tokenizer.ggml.bos_token_id'].parts[-1])].piece
     if 'tokenizer.ggml.eos_token_id' in result.fields:
-        tokenizer_config["eos_token"] = vocab.pieces[int(result.fields['tokenizer.ggml.eos_token_id'].parts[-1])].piece
+        tokenizer_config["eos_token"] = vocab.pieces[int(
+            result.fields['tokenizer.ggml.eos_token_id'].parts[-1])].piece
     if 'tokenizer.ggml.padding_token_id' in result.fields:
-        tokenizer_config["pad_token"] = vocab.pieces[int(result.fields['tokenizer.ggml.padding_token_id'].parts[-1])].piece
+        tokenizer_config["pad_token"] = vocab.pieces[int(
+            result.fields['tokenizer.ggml.padding_token_id'].parts[-1])].piece
     if 'tokenizer.ggml.unknown_token_id' in result.fields:
-        tokenizer_config["unk_token"] = vocab.pieces[int(result.fields['tokenizer.ggml.unknown_token_id'].parts[-1])].piece
+        tokenizer_config["unk_token"] = vocab.pieces[int(
+            result.fields['tokenizer.ggml.unknown_token_id'].parts[-1])].piece
     if 'tokenizer.ggml.add_bos_token' in result.fields:
-        tokenizer_config["add_bos_token"] = bool(result.fields['tokenizer.ggml.add_bos_token'].parts[-1])
+        tokenizer_config["add_bos_token"] = bool(
+            result.fields['tokenizer.ggml.add_bos_token'].parts[-1])
     if 'tokenizer.ggml.add_eos_token' in result.fields:
-        tokenizer_config["add_eos_token"] = bool(result.fields['tokenizer.ggml.add_eos_token'].parts[-1])
+        tokenizer_config["add_eos_token"] = bool(
+            result.fields['tokenizer.ggml.add_eos_token'].parts[-1])
     if 'tokenizer.chat_template' in result.fields:
-        tokenizer_config["chat_template"] = str(bytes(result.fields['tokenizer.chat_template'].parts[-1]), encoding="utf-8")
-    json.dump(tokenizer_config, open(os.path.join(save_dir, "tokenizer_config.json"), 'w'), indent=2)
+        tokenizer_config["chat_template"] = str(bytes(
+            result.fields['tokenizer.chat_template'].parts[-1]),
+                                                encoding="utf-8")
+    with open(os.path.join(save_dir, "tokenizer_config.json"), 'w') as f:
+        json.dump(tokenizer_config, f, indent=2)
 
     # write config
     context_length = int(result.fields['llama.context_length'].parts[-1])
     n_layer = int(result.fields['llama.block_count'].parts[-1])
     n_head = int(result.fields['llama.attention.head_count'].parts[-1])
-    n_local_heads = int(result.fields['llama.attention.head_count_kv'].parts[-1])
-    intermediate_size = int(result.fields['llama.feed_forward_length'].parts[-1])
-    norm_eps = float(result.fields['llama.attention.layer_norm_rms_epsilon'].parts[-1])
+    n_local_heads = int(
+        result.fields['llama.attention.head_count_kv'].parts[-1])
+    intermediate_size = int(
+        result.fields['llama.feed_forward_length'].parts[-1])
+    norm_eps = float(
+        result.fields['llama.attention.layer_norm_rms_epsilon'].parts[-1])
     dim = int(result.fields['llama.embedding_length'].parts[-1])
     kv_dim = dim // n_head * n_local_heads
     arch = "MixtralForCausalLM"
@@ -75,7 +92,7 @@ def convert_to_state_dict(checkpoint, save_dir, max_shard_size, safe_serializati
     else:
         arch = "LlamaForCausalLM"
         name = "llama"
-    model_config= {
+    model_config = {
         "architectures": [arch],
         "bos_token_id": 1,
         "eos_token_id": 2,
@@ -92,11 +109,15 @@ def convert_to_state_dict(checkpoint, save_dir, max_shard_size, safe_serializati
         "vocab_size": vocab_size
     }
     if 'llama.rope.freq_base' in result.fields:
-        model_config['rope_theta'] = float(result.fields['llama.rope.freq_base'].parts[-1])
+        model_config['rope_theta'] = float(
+            result.fields['llama.rope.freq_base'].parts[-1])
     if 'llama.expert_count' in result.fields:
-        model_config['num_local_experts'] = int(result.fields['llama.expert_count'].parts[-1])
-        model_config['num_experts_per_tok'] = int(result.fields['llama.expert_used_count'].parts[-1])
-    json.dump(model_config, open(os.path.join(save_dir, "config.json"), 'w'), indent=2)
+        model_config['num_local_experts'] = int(
+            result.fields['llama.expert_count'].parts[-1])
+        model_config['num_experts_per_tok'] = int(
+            result.fields['llama.expert_used_count'].parts[-1])
+    with open(os.path.join(save_dir, "config.json"), 'w') as f:
+        json.dump(model_config, f, indent=2)
 
     # write tensor
     tensor_mapping = {
@@ -108,15 +129,25 @@ def convert_to_state_dict(checkpoint, save_dir, max_shard_size, safe_serializati
         "blk.{bid}.attn_k": ("model.layers.{bid}.self_attn.k_proj", kv_dim),
         "blk.{bid}.attn_v": ("model.layers.{bid}.self_attn.v_proj", kv_dim),
         "blk.{bid}.attn_output": ("model.layers.{bid}.self_attn.o_proj", dim),
-        "blk.{bid}.attn_rot_embd": ("model.layers.{bid}.self_attn.rotary_emb.inv_freq", -1),
-        "blk.{bid}.ffn_norm": ("model.layers.{bid}.post_attention_layernorm", -1),
-        "blk.{bid}.ffn_up": ("model.layers.{bid}.mlp.up_proj", intermediate_size),
+        "blk.{bid}.attn_rot_embd":
+        ("model.layers.{bid}.self_attn.rotary_emb.inv_freq", -1),
+        "blk.{bid}.ffn_norm": ("model.layers.{bid}.post_attention_layernorm",
+                               -1),
+        "blk.{bid}.ffn_up": ("model.layers.{bid}.mlp.up_proj",
+                             intermediate_size),
         "blk.{bid}.ffn_down": ("model.layers.{bid}.mlp.down_proj", dim),
-        "blk.{bid}.ffn_gate": ("model.layers.{bid}.mlp.gate_proj", intermediate_size),
-        "blk.{bid}.ffn_up.{xid}": ("model.layers.{bid}.block_sparse_moe.experts.{xid}.w3", intermediate_size),
-        "blk.{bid}.ffn_down.{xid}": ("model.layers.{bid}.block_sparse_moe.experts.{xid}.w2", dim),
-        "blk.{bid}.ffn_gate.{xid}": ("model.layers.{bid}.block_sparse_moe.experts.{xid}.w1", intermediate_size),
-        "blk.{bid}.ffn_gate_inp": ("model.layers.{bid}.block_sparse_moe.gate", model_config.get('num_local_experts', 1)),
+        "blk.{bid}.ffn_gate": ("model.layers.{bid}.mlp.gate_proj",
+                               intermediate_size),
+        "blk.{bid}.ffn_up.{xid}":
+        ("model.layers.{bid}.block_sparse_moe.experts.{xid}.w3",
+         intermediate_size),
+        "blk.{bid}.ffn_down.{xid}":
+        ("model.layers.{bid}.block_sparse_moe.experts.{xid}.w2", dim),
+        "blk.{bid}.ffn_gate.{xid}":
+        ("model.layers.{bid}.block_sparse_moe.experts.{xid}.w1",
+         intermediate_size),
+        "blk.{bid}.ffn_gate_inp": ("model.layers.{bid}.block_sparse_moe.gate",
+                                   model_config.get('num_local_experts', 1)),
     }
     mapping = {}
     max_block_num = 200
@@ -142,18 +173,27 @@ def convert_to_state_dict(checkpoint, save_dir, max_shard_size, safe_serializati
         state_dict[new_key] = data
     if max_shard_size == "0":
         if safe_serialization:
-            safe_save_file(state_dict, os.path.join(save_dir, SAFE_WEIGHTS_NAME), metadata={"format": "pt"})
+            safe_save_file(state_dict,
+                           os.path.join(save_dir, SAFE_WEIGHTS_NAME),
+                           metadata={"format": "pt"})
         else:
             torch.save(state_dict, os.path.join(save_dir, WEIGHTS_NAME))
     else:
-        shards, index = shard_checkpoint(state_dict, max_shard_size, SAFE_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME)
+        shards, index = shard_checkpoint(
+            state_dict, max_shard_size,
+            SAFE_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME)
         for shard_file, shard in shards.items():
             if safe_serialization:
-                safe_save_file(shard, os.path.join(save_dir, shard_file), metadata={"format": "pt"})
+                safe_save_file(shard,
+                               os.path.join(save_dir, shard_file),
+                               metadata={"format": "pt"})
             else:
                 torch.save(shard, os.path.join(save_dir, shard_file))
         if index is not None:
-            save_index_file = SAFE_WEIGHTS_INDEX_NAME if safe_serialization else WEIGHTS_INDEX_NAME
+            if safe_serialization:
+                save_index_file = SAFE_WEIGHTS_INDEX_NAME
+            else:
+                save_index_file = WEIGHTS_INDEX_NAME
             save_index_file = os.path.join(save_dir, save_index_file)
             # Save the index as well
             with open(save_index_file, "w", encoding="utf-8") as f:
@@ -161,14 +201,24 @@ def convert_to_state_dict(checkpoint, save_dir, max_shard_size, safe_serializati
                 f.write(content)
 
 
-
 if __name__ == '__main__':
     import argparse
-    parser = argparse.ArgumentParser(description='Convert GGUF checkpoints to torch')
+    parser = argparse.ArgumentParser(
+        description='Convert GGUF checkpoints to torch')
 
     parser.add_argument('--input', type=str, help='The path to GGUF file')
-    parser.add_argument('--output', type=str, help='The path to output directory')
-    parser.add_argument('--max-shard-size', default="0", type=str, help='Shard the model in specified shard size, e.g. 10GB. 0 to disable')
-    parser.add_argument('--safetensors', action='store_true', help='Save in .safetensors format')
+    parser.add_argument('--output',
+                        type=str,
+                        help='The path to output directory')
+    parser.add_argument(
+        '--max-shard-size',
+        default="0",
+        type=str,
+        help='Shard the model in specified shard size, e.g. 10GB. 0 to disable'
+    )
+    parser.add_argument('--safetensors',
+                        action='store_true',
+                        help='Save in .safetensors format')
     args = parser.parse_args()
-    convert_to_state_dict(args.input, args.output, args.max_shard_size, args.safetensors)
+    convert_to_state_dict(args.input, args.output, args.max_shard_size,
+                          args.safetensors)

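For reference, a minimal sketch of driving the reformatted converter programmatically; the GGUF path, output directory, and shard size below are placeholder values, and the import assumes the script is on the Python path:

# Hypothetical invocation of examples/gguf_to_torch.py; the file paths are
# made up for illustration only.
from gguf_to_torch import convert_to_state_dict

convert_to_state_dict(
    checkpoint="llama-2-7b.Q4_K_M.gguf",  # input GGUF checkpoint
    save_dir="./llama-2-7b-torch",        # created if it does not exist
    max_shard_size="10GB",                # "0" disables sharding
    safe_serialization=True)              # write .safetensors output
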
+ 19 - 10
examples/gradio_server.py

@@ -12,33 +12,42 @@ def http_bot(prompt):
         "stream": True,
         "max_tokens": 512,
     }
-    response = requests.post(args.model_url, headers=headers, json=pload, stream=True)
-
-    for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0"):
+    response = requests.post(args.model_url,
+                             headers=headers,
+                             json=pload,
+                             stream=True)
+
+    for chunk in response.iter_lines(chunk_size=8192,
+                                     decode_unicode=False,
+                                     delimiter=b"\0"):
         if chunk:
             data = json.loads(chunk.decode("utf-8"))
             output = data["text"][0]
             yield output
 
+
 def build_demo():
     with gr.Blocks() as demo:
-        gr.Markdown(
-            "# Aphrodite text completion demo\n"
-        )
-        inputbox = gr.Textbox(label="Input", placeholder="Enter text and press ENTER")
-        outputbox = gr.Textbox(label="Output", placeholder="Generated result from the model")
+        gr.Markdown("# Aphrodite text completion demo\n")
+        inputbox = gr.Textbox(label="Input",
+                              placeholder="Enter text and press ENTER")
+        outputbox = gr.Textbox(label="Output",
+                               placeholder="Generated result from the model")
         inputbox.submit(http_bot, [inputbox], [outputbox])
     return demo
 
+
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--host", type=str, default="localhost")
     parser.add_argument("--port", type=int, default=8001)
-    parser.add_argument("--model-url", type=str, default="http://localhost:2242/api/v1/generate")
+    parser.add_argument("--model-url",
+                        type=str,
+                        default="http://localhost:2242/api/v1/generate")
 
     args = parser.parse_args()
 
     demo = build_demo()
     demo.queue(concurrency_count=100).launch(server_name=args.host,
                                              server_port=args.port,
-                                             share=True)    
+                                             share=True)

+ 42 - 12
examples/marlin/convert.py

@@ -1,4 +1,6 @@
-import torch, argparse, copy
+import torch
+import argparse
+import copy
 from transformers import AutoModelForCausalLM, AutoTokenizer
 from auto_gptq.nn_modules.qlinear.qlinear_exllama import QuantLinear
 from marlin import Layer as MarlinLayer
@@ -9,26 +11,47 @@ parser.add_argument("--model-id", type=str)
 parser.add_argument("--save-path", type=str)
 parser.add_argument("--do-generation", action="store_true")
 
+
 def _validate_compatibility(model):
     if not hasattr(model.config, "quantization_config"):
-        raise ValueError("Must be a quantized model to convert to Marlin Format")
+        raise ValueError(
+            "Must be a quantized model to convert to Marlin Format")
     quantization_config = model.config.quantization_config
     if quantization_config.quant_method != "gptq":
-        raise ValueError(f"Only GPTQ models can be converted to Marlin format. You passed a model with quant_method={quantization_config.quant_method}")
+        raise ValueError(
+            "Only GPTQ models can be converted to Marlin format. You passed a "
+            f"model with quant_method={quantization_config.quant_method}")
     if quantization_config.bits != 4:
-        raise ValueError(f"Only 4 bit quantized models can be converted to Marlin format. You passed a model with bits={quantization_config.bits}")
+        raise ValueError(
+            "Only 4 bit quantized models can be converted to Marlin format. "
+            f"You passed a model with bits={quantization_config.bits}")
     if quantization_config.group_size != 128:
-        raise ValueError(f"Only group size 128 models can be converted to Marlin format. You passed a model with group_size={quantization_config.group_size}")
+        raise ValueError(
+            "Only group size 128 models can be converted to Marlin format. You "
+            f"passed a model with group_size={quantization_config.group_size}")
     if not quantization_config.sym:
-        raise ValueError(f"Only models with symmetric quantization can be converted to Marlin Format. You passed a model with sym={quantization_config.sym}")
+        raise ValueError(
+            "Only models with symmetric quantization can be converted to "
+            "Marlin Format. You passed a model with sym="
+            f"{quantization_config.sym}")
     if quantization_config.desc_act:
-        raise ValueError(f"Models with act order quantization cannot be converted to Marlin Format. You passed a model with desc_act={quantization_config.desc_act}")
+        raise ValueError(
+            "Models with act order quantization cannot be converted to "
+            "Marlin Format. You passed a model with desc_act="
+            f"{quantization_config.desc_act}")
+
 
 @torch.no_grad()
 def unpack_4bit_to_32bit_signed(qweight, qzeros):
     # Unpack 4-bit values and interpret them as signed integers
-    unpacked_weights = torch.zeros((qweight.shape[0]*8, qweight.shape[1]), dtype=torch.int8, device=qweight.device, requires_grad=False)
-    unpacked_zeros = torch.zeros((qzeros.shape[0], qzeros.shape[1]*8), dtype=torch.int8, device=qzeros.device, requires_grad=False)
+    unpacked_weights = torch.zeros((qweight.shape[0] * 8, qweight.shape[1]),
+                                   dtype=torch.int8,
+                                   device=qweight.device,
+                                   requires_grad=False)
+    unpacked_zeros = torch.zeros((qzeros.shape[0], qzeros.shape[1] * 8),
+                                 dtype=torch.int8,
+                                 device=qzeros.device,
+                                 requires_grad=False)
 
     for row in range(unpacked_weights.shape[0]):
         i = row % 8
@@ -40,10 +63,12 @@ def unpack_4bit_to_32bit_signed(qweight, qzeros):
 
     return unpacked_weights, unpacked_zeros + 1
 
+
 @torch.no_grad()
 def dequantize_weight(layer):
     qweight, qzeros, scales = layer.qweight, layer.qzeros, layer.scales
-    unpacked_qweight, unpacked_qzeros = unpack_4bit_to_32bit_signed(qweight, qzeros)
+    unpacked_qweight, unpacked_qzeros = unpack_4bit_to_32bit_signed(
+        qweight, qzeros)
     group_size = unpacked_qweight.shape[0] // scales.shape[0]
     scales = scales.repeat_interleave(group_size, dim=0)
     unpacked_qzeros = unpacked_qzeros.repeat_interleave(group_size, dim=0)
@@ -51,6 +76,7 @@ def dequantize_weight(layer):
 
     return unpacked_qweight.T
 
+
 @torch.no_grad()
 def convert_model(model, verbose=True):
     for name, module in model.named_modules():
@@ -77,7 +103,8 @@ def convert_model(model, verbose=True):
             infeatures=linear_module.in_features,
             outfeatures=linear_module.out_features,
             groupsize=model.config.quantization_config.group_size)
-        new_module.pack(linear_module, scales=copy.deepcopy(module.scales.data.t()))
+        new_module.pack(linear_module,
+                        scales=copy.deepcopy(module.scales.data.t()))
 
         # Save to parent.
         parent_module = model.get_submodule(parent_name)
@@ -90,6 +117,7 @@ def convert_model(model, verbose=True):
 
     return model
 
+
 @torch.no_grad()
 def dequantize_model(model, verbose=True):
     for name, module in model.named_modules():
@@ -112,7 +140,8 @@ def dequantize_model(model, verbose=True):
             bias=False,
             dtype=torch.float16)
         new_module.weight.data.copy_(dequantized_weight_cpu)
-        new_module.scales = torch.nn.Parameter(copy.deepcopy(module.scales.data))
+        new_module.scales = torch.nn.Parameter(
+            copy.deepcopy(module.scales.data))
 
         # Save to parent.
         parent_module = model.get_submodule(parent_name)
@@ -124,6 +153,7 @@ def dequantize_model(model, verbose=True):
 
     return model
 
+
 if __name__ == "__main__":
     args = parser.parse_args()
     model_id = args.model_id

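The unpack_4bit_to_32bit_signed helper above allocates eight times as many rows as qweight because eight 4-bit values share each 32-bit word. A standalone illustration of that packing in plain Python (not code from convert.py):

# Illustration only: eight 4-bit values packed into one 32-bit word,
# with value i occupying bits 4*i .. 4*i+3.
nibbles = [1, 7, 0, 15, 3, 9, 12, 5]
packed = 0
for i, value in enumerate(nibbles):
    packed |= (value & 0xF) << (4 * i)
recovered = [(packed >> (4 * i)) & 0xF for i in range(8)]
assert recovered == nibbles
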
+ 4 - 3
examples/offline_inference.py

@@ -3,7 +3,7 @@ from aphrodite import LLM, SamplingParams
 # Sample prompts.
 prompts = [
     "<|system|>Enter chat mode.<|user|>Hello!<|model|>",
-    "<|system|>Enter RP mode.<|model|>Hello!<|user|>What are you doing?<|model|>",
+    "<|system|>Enter RP mode.<|model|>Hello!<|user|>What are you doing?",
     "<|system|>Enter chat mode.<|user|>What is the meaning of life?<|model|>",
     "<|system|>Enter QA mode.<|user|>What is a man?<|model|>A miserable",
 ]
@@ -11,7 +11,8 @@ prompts = [
 sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
 
 # Create an LLM.
-llm = LLM(model="PygmalionAI/pygmalion-2-7b") # pass additional arguments here, such as `quantization`
+llm = LLM(model="PygmalionAI/pygmalion-2-7b"
+          )  # pass additional arguments here, such as `quantization`
 # Generate texts from the prompts. The output is a list of RequestOutput objects
 # that contain the prompt, generated text, and other information.
 outputs = llm.generate(prompts, sampling_params)
@@ -19,4 +20,4 @@ outputs = llm.generate(prompts, sampling_params)
 for output in outputs:
     prompt = output.prompt
     generated_text = output.outputs[0].text
-    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
+    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

+ 44 - 24
examples/slora_inference.py

@@ -1,6 +1,6 @@
 """
-This example shows how to use the multi-LoRA functionality for offline inference.
-Requires HuggingFace credentials for access to Llama2.
+This example shows how to use the multi-LoRA functionality for offline
+inference. Requires HuggingFace credentials for access to Llama2.
 """
 
 from typing import Optional, List, Tuple
@@ -21,45 +21,65 @@ def create_test_prompts(lora_path: str) -> List[Tuple[str, SamplingParams]]:
     first adapter have finished.
     """
     return [
-        ("A robot may not injure a human being",
-         SamplingParams(temperature=0.0,
-                        # logprobs=1,
-                        prompt_logprobs=1,
-                        max_tokens=128), None),
+        (
+            "A robot may not injure a human being",
+            SamplingParams(
+                temperature=0.0,
+                # logprobs=1,
+                prompt_logprobs=1,
+                max_tokens=128),
+            None),
         ("To be or not to be,",
          SamplingParams(temperature=0.8,
                         top_k=5,
                         presence_penalty=0.2,
                         max_tokens=128), None),
-        ("[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]",
-         SamplingParams(temperature=0.0,
-                        # logprobs=1,
-                        prompt_logprobs=1,
-                        max_tokens=128,
-                        stop_token_ids=[32003]),
-         LoRARequest("l2-lora-test", 1, lora_path)),
-        ("[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? [/user] [assistant]",
+        (
+            """[user] Write a SQL query to answer the question based on the
+            table schema.\n\n context: CREATE TABLE table_name_74
+            (icao VARCHAR, airport VARCHAR)\n\n
+            question: Name the ICAO for lilongwe
+            international airport [/user] [assistant]""",
+            SamplingParams(
+                temperature=0.0,
+                # logprobs=1,
+                prompt_logprobs=1,
+                max_tokens=128,
+                stop_token_ids=[32003]),
+            LoRARequest("l2-lora-test", 1, lora_path)),
+        ("""[user] Write a SQL query to answer the question based on the table
+         schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR,
+         elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector
+         what is under nationality? [/user] [assistant]""",
          SamplingParams(n=3,
                         best_of=3,
                         temperature=0.8,
                         max_tokens=128,
                         stop_token_ids=[32003]),
          LoRARequest("l2-lora-test", 1, lora_path)),
-        ("[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]",
-         SamplingParams(temperature=0.0,
-                        # logprobs=1,
-                        prompt_logprobs=1,
-                        max_tokens=128,
-                        stop_token_ids=[32003]),
-         LoRARequest("l2-lora-test2", 2, lora_path)),
-        ("[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? [/user] [assistant]",
+        (
+            """[user] Write a SQL query to answer the question based on the
+            table schema.\n\n context: CREATE TABLE table_name_74 (icao
+            VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe
+            international airport [/user] [assistant]""",
+            SamplingParams(
+                temperature=0.0,
+                # logprobs=1,
+                prompt_logprobs=1,
+                max_tokens=128,
+                stop_token_ids=[32003]),
+            LoRARequest("l2-lora-test2", 2, lora_path)),
+        ("""[user] Write a SQL query to answer the question based on the table
+         schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR,
+         elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector
+         what is under nationality? [/user] [assistant]""",
          SamplingParams(n=3,
                         best_of=3,
                         temperature=0.9,
                         max_tokens=128,
                         stop_token_ids=[32003]),
          LoRARequest("l2-lora-test", 1, lora_path)),
-    ] # type: ignore
+    ]  # type: ignore
 
 
 def process_requests(engine: AphroditeEngine,

+ 89 - 9
formatting.sh

@@ -5,9 +5,9 @@
 #    # Do work and commit your work.
 
 #    # Format files that differ from origin/main.
-#    bash format.sh
+#    bash formatting.sh
 
-#    # Commit changed files with message 'Run yapf and pylint'
+#    # Commit changed files with message 'Run yapf and ruff'
 #
 #
 # YAPF + Clang formatter (if installed). This script formats all changed files from the last mergebase.
@@ -22,8 +22,9 @@ ROOT="$(git rev-parse --show-toplevel)"
 builtin cd "$ROOT" || exit 1
 
 YAPF_VERSION=$(yapf --version | awk '{print $2}')
-PYLINT_VERSION=$(pylint --version | head -n 1 | awk '{print $2}')
+RUFF_VERSION=$(ruff --version | awk '{print $2}')
 MYPY_VERSION=$(mypy --version | awk '{print $2}')
+CODESPELL_VERSION=$(codespell --version)
 
 # # params: tool name, tool version, required version
 tool_version_check() {
@@ -34,8 +35,9 @@ tool_version_check() {
 }
 
 tool_version_check "yapf" $YAPF_VERSION "$(grep yapf requirements-dev.txt | cut -d'=' -f3)"
-tool_version_check "pylint" $PYLINT_VERSION "$(grep "pylint==" requirements-dev.txt | cut -d'=' -f3)"
+tool_version_check "ruff" $RUFF_VERSION "$(grep "ruff==" requirements-dev.txt | cut -d'=' -f3)"
 tool_version_check "mypy" "$MYPY_VERSION" "$(grep mypy requirements-dev.txt | cut -d'=' -f3)"
+tool_version_check "codespell" "$CODESPELL_VERSION" "$(grep codespell requirements-dev.txt | cut -d'=' -f3)"
 
 YAPF_FLAGS=(
     '--recursive'
@@ -71,7 +73,7 @@ format_changed() {
 
 # Format all files
 format_all() {
-    yapf --in-place "${YAPF_FLAGS[@]}" "${YAPF_EXCLUDES[@]}" aphrodite tests
+    yapf --in-place "${YAPF_FLAGS[@]}" "${YAPF_EXCLUDES[@]}" .
 }
 
 ## This flag formats individual files. --files *must* be the first command line
@@ -86,12 +88,90 @@ else
    # Format only the files that changed in last commit.
    format_changed
 fi
-echo 'Aphrodite Engine yapf: Done'
+echo 'Aphrodite yapf: Done'
 
+CODESPELL_EXCLUDES=(
+    '--skip' '*docs/source/_build/**'
+)
+
+# check spelling of specified files
+spell_check() {
+    codespell "$@"
+}
+
+spell_check_all() {
+    codespell --toml pyproject.toml "${CODESPELL_EXCLUDES[@]}"
+}
+
+# Spelling check of files that differ from the main branch.
+spell_check_changed() {
+    # The `if` guard ensures that the list of filenames is not empty, which
+    # could cause codespell to receive 0 positional arguments, making it hang
+    # waiting for STDIN.
+    #
+    # `diff-filter=ACM` and $MERGEBASE is to ensure we only lint files that
+    # exist on both branches.
+    MERGEBASE="$(git merge-base origin/main HEAD)"
+
+    if ! git diff --diff-filter=ACM --quiet --exit-code "$MERGEBASE" -- '*.py' '*.pyi' &>/dev/null; then
+        git diff --name-only --diff-filter=ACM "$MERGEBASE" -- '*.py' '*.pyi' | xargs \
+             codespell "${CODESPELL_EXCLUDES[@]}"
+    fi
+}
+
+# Run Codespell
+## This flag runs spell check of individual files. --files *must* be the first command line
+## arg to use this option.
+if [[ "$1" == '--files' ]]; then
+   spell_check "${@:2}"
+   # If `--all` is passed, then any further arguments are ignored and the
+   # entire python directory is linted.
+elif [[ "$1" == '--all' ]]; then
+   spell_check_all
+else
+   # Check spelling only of the files that changed in last commit.
+   spell_check_changed
+fi
+echo 'Aphrodite codespell: Done'
+
+
+# Lint specified files
+lint() {
+    ruff "$@"
+}
 
-# Run Pylint
-echo 'Aphrodite Engine Pylint:'
-pylint aphrodite tests
+# Lint files that differ from main branch. Ignores dirs that are not slated
+# for autolint yet.
+lint_changed() {
+    # The `if` guard ensures that the list of filenames is not empty, which
+    # could cause ruff to receive 0 positional arguments, making it hang
+    # waiting for STDIN.
+    #
+    # `diff-filter=ACM` and $MERGEBASE is to ensure we only lint files that
+    # exist on both branches.
+    MERGEBASE="$(git merge-base origin/main HEAD)"
+
+    if ! git diff --diff-filter=ACM --quiet --exit-code "$MERGEBASE" -- '*.py' '*.pyi' &>/dev/null; then
+        git diff --name-only --diff-filter=ACM "$MERGEBASE" -- '*.py' '*.pyi' | xargs \
+             ruff
+    fi
+
+}
+
+# Run Ruff
+echo 'Aphrodite ruff:'
+### This flag lints individual files. --files *must* be the first command line
+### arg to use this option.
+if [[ "$1" == '--files' ]]; then
+   lint "${@:2}"
+   # If `--all` is passed, then any further arguments are ignored and the
+   # entire python directory is linted.
+elif [[ "$1" == '--all' ]]; then
+   lint aphrodite tests
+else
+   # Lint only the files that changed in the last commit.
+   lint_changed
+fi
 
 if ! git diff --quiet &>/dev/null; then
     echo 'Reformatted files. Please review and stage the changes.'

+ 45 - 0
pyproject.toml

@@ -7,3 +7,48 @@ requires = [
     "wheel",
 ]
 build-backend = "setuptools.build_meta"
+
+
+[tool.ruff]
+# Allow lines to be as long as 80.
+line-length = 80
+
+[tool.ruff.lint]
+select = [
+    # pycodestyle
+    "E",
+    # Pyflakes
+    "F",
+    # pyupgrade
+    # "UP",
+    # flake8-bugbear
+    "B",
+    # flake8-simplify
+    "SIM",
+    # isort
+    # "I",
+]
+ignore = [
+    # wildcard imports
+    "F405", "F403",
+    # lambda expression assignment
+    "E731",
+    # .strip() with multi-character strings
+    "B005",
+    # Loop control variable not used within loop body
+    "B007",
+]
+
+[tool.mypy]
+python_version = "3.8"
+
+ignore_missing_imports = true
+
+files = "aphrodite"
+# TODO: Include the code from Megatron and HuggingFace.
+exclude = "aphrodite/modeling/megatron/|aphrodite/modeling/models/|aphrodite/endpoints/kobold/klite.embd"
+
+
+[tool.codespell]
+ignore-words-list = "dout, te, indicies"
+skip = "./aphrodite/endpoints/kobold/klite.embd"

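In practice this configuration means ruff reports pycodestyle (E), Pyflakes (F), bugbear (B), and simplify (SIM) findings while staying quiet on the ignored codes. A hypothetical snippet, not taken from the repository, showing the split:

import os                    # F401 unused import: reported ("F" is selected)

square = lambda x: x * x     # E731 lambda assignment: in `ignore`, so allowed

for item in [1, 2, 3]:       # B007 unused loop variable: also ignored
    print("hello")
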
+ 15 - 2
requirements-dev.txt

@@ -1,6 +1,9 @@
 # formatting
 yapf==0.32.0
-pylint==2.8.2
+toml==0.10.2
+tomli==2.0.1
+ruff==0.1.5
+codespell==2.2.6
 
 # type checking
 mypy==0.991
@@ -11,4 +14,14 @@ types-setuptools
 # testing
 pytest
 pytest-forked
-pytest-asyncio
+pytest-asyncio
+pytest-rerunfailures
+httpx
+einops # required for MPT
+openai
+requests
+ray
+peft
+
+# Benchmarking
+aiohttp

+ 56 - 36
setup.py

@@ -11,22 +11,22 @@ from packaging.version import parse, Version
 import setuptools
 import torch
 import torch.utils.cpp_extension as torch_cpp_ext
-from torch.utils.cpp_extension import (
-    BuildExtension, CUDAExtension, CUDA_HOME, ROCM_HOME)
+from torch.utils.cpp_extension import (BuildExtension, CUDAExtension,
+                                       CUDA_HOME, ROCM_HOME)
 
 ROOT_DIR = os.path.dirname(__file__)
 
 MAIN_CUDA_VERSION = "12.1"
 
 # Supported NVIDIA GPU architectures.
-NVIDIA_SUPPORTED_ARCHS = {
-    "6.1", "7.0", "7.5", "8.0", "8.6", "8.9", "9.0"
-}
+NVIDIA_SUPPORTED_ARCHS = {"6.1", "7.0", "7.5", "8.0", "8.6", "8.9", "9.0"}
 ROCM_SUPPORTED_ARCHS = {"gfx90a", "gfx942", "gfx1100"}
 
+
 def _is_hip() -> bool:
     return torch.version.hip is not None
 
+
 def _is_cuda() -> bool:
     return torch.version.cuda is not None
 
@@ -36,7 +36,6 @@ CXX_FLAGS = ["-g", "-O2", "-std=c++17"]
 # TODO: Should we use -O3?
 NVCC_FLAGS = ["-O2", "-std=c++17"]
 
-
 if _is_hip():
     if ROCM_HOME is None:
         raise RuntimeError(
@@ -74,10 +73,12 @@ def get_hipcc_rocm_version():
         print("Could not find HIP version in the output")
         return None
 
+
 def glob(pattern: str):
     root = Path(__name__).parent
     return [str(p) for p in root.glob(pattern)]
 
+
 def get_nvcc_cuda_version(cuda_dir: str) -> Version:
     """Get the CUDA version from nvcc.
 
@@ -90,10 +91,12 @@ def get_nvcc_cuda_version(cuda_dir: str) -> Version:
     nvcc_cuda_version = parse(output[release_idx].split(",")[0])
     return nvcc_cuda_version
 
+
 def get_pytorch_rocm_arch() -> Set[str]:
     env_arch_list = os.environ.get("PYTORCH_ROCM_ARCH", None)
 
-    # If we don't have PYTORCH_ROCM_ARCH specified pull the list from rocm_agent_enumerator
+    # If we don't have PYTORCH_ROCM_ARCH specified pull the list from
+    # rocm_agent_enumerator
     if env_arch_list is None:
         command = "rocm_agent_enumerator"
         env_arch_list = subprocess.check_output([command]).decode('utf-8')\
@@ -124,6 +127,7 @@ def get_pytorch_rocm_arch() -> Set[str]:
             stacklevel=2)
     return arch_list
 
+
 def get_torch_arch_list() -> Set[str]:
     # TORCH_CUDA_ARCH_LIST can have one or more architectures,
     # e.g. "8.0" or "7.5,8.0,8.6+PTX". Here, the "8.6+PTX" option asks the
@@ -202,11 +206,11 @@ if _is_cuda():
             "CUDA 11.1 or higher is required for compute capability 8.6.")
     if nvcc_cuda_version < Version("11.8"):
         if any(cc.startswith("8.9") for cc in compute_capabilities):
-            # CUDA 11.8 is required to generate the code targeting compute capability 8.9.
-            # However, GPUs with compute capability 8.9 can also run the code generated by
-            # the previous versions of CUDA 11 and targeting compute capability 8.0.
-            # Therefore, if CUDA 11.8 is not available, we target compute capability 8.0
-            # instead of 8.9.
+            # CUDA 11.8 is required to generate the code targeting compute
+            # capability 8.9. However, GPUs with compute capability 8.9 can
+            # also run the code generated by the previous versions of CUDA 11
+            # and targeting compute capability 8.0. Therefore, if CUDA 11.8 is
+            # not available, we target compute capability 8.0 instead of 8.9.
             warnings.warn(
                 "CUDA 11.8 or higher is required for compute capability 8.9. "
                 "Targeting compute capability 8.0 instead.",
@@ -242,7 +246,7 @@ if _is_cuda():
         nvcc_threads = int(os.getenv("NVCC_THREADS", 8))
         num_threads = min(os.cpu_count(), nvcc_threads)
         NVCC_FLAGS += ["--threads", str(num_threads)]
-    
+
     if nvcc_cuda_version >= Version("11.8"):
         NVCC_FLAGS += ["-DENABLE_FP8_E5M2"]
 
@@ -258,7 +262,8 @@ if _is_cuda():
         with contextlib.suppress(ValueError):
             torch_cpp_ext.COMMON_NVCC_FLAGS.remove(flag)
 
-    install_punica = bool(int(os.getenv("APHRODITE_INSTALL_PUNICA_KERNELS", "1")))
+    install_punica = bool(
+        int(os.getenv("APHRODITE_INSTALL_PUNICA_KERNELS", "1")))
     device_count = torch.cuda.device_count()
     for i in range(device_count):
         major, minor = torch.cuda.get_device_capability(i)
@@ -276,8 +281,9 @@ if _is_cuda():
                     "nvcc": NVCC_FLAGS_PUNICA,
                 },
             ))
-    
-    install_hadamard = bool(int(os.getenv("APHRODITE_INSTALL_HADAMARD_KERNELS", "1")))
+
+    install_hadamard = bool(
+        int(os.getenv("APHRODITE_INSTALL_HADAMARD_KERNELS", "1")))
     device_count = torch.cuda.device_count()
     for i in range(device_count):
         major, minor = torch.cuda.get_device_capability(i)
@@ -288,8 +294,10 @@ if _is_cuda():
         ext_modules.append(
             CUDAExtension(
                 name="aphrodite._hadamard_C",
-                sources=["kernels/hadamard/fast_hadamard_transform.cpp",
-                         "kernels/hadamard/fast_hadamard_transform_cuda.cu"],
+                sources=[
+                    "kernels/hadamard/fast_hadamard_transform.cpp",
+                    "kernels/hadamard/fast_hadamard_transform_cuda.cu"
+                ],
                 extra_compile_args={
                     "cxx": CXX_FLAGS,
                     "nvcc": NVCC_FLAGS,
@@ -313,17 +321,25 @@ aphrodite_extension_sources = [
 ]
 
 if _is_cuda():
-    aphrodite_extension_sources.append("kernels/quantization/awq/gemm_kernels.cu")
-    aphrodite_extension_sources.append("kernels/quantization/quip/origin_order.cu")
-    aphrodite_extension_sources.append("kernels/quantization/marlin/marlin_cuda_kernel.cu")
-    aphrodite_extension_sources.append("kernels/all_reduce/custom_all_reduce.cu")
-    aphrodite_extension_sources.append("kernels/quantization/aqlm/aqlm_cuda_entry.cpp")
-    aphrodite_extension_sources.append("kernels/quantization/aqlm/aqlm_cuda_kernel.cu")
+    aphrodite_extension_sources.append(
+        "kernels/quantization/awq/gemm_kernels.cu")
+    aphrodite_extension_sources.append(
+        "kernels/quantization/quip/origin_order.cu")
+    aphrodite_extension_sources.append(
+        "kernels/quantization/marlin/marlin_cuda_kernel.cu")
+    aphrodite_extension_sources.append(
+        "kernels/all_reduce/custom_all_reduce.cu")
+    aphrodite_extension_sources.append(
+        "kernels/quantization/aqlm/aqlm_cuda_entry.cpp")
+    aphrodite_extension_sources.append(
+        "kernels/quantization/aqlm/aqlm_cuda_kernel.cu")
     aphrodite_extension_sources.append(
         "kernels/quantization/bitsandbytes/int4_fp16_gemm_kernels.cu")
-    aphrodite_extension_sources.append("kernels/quantization/bitsandbytes/format.cu")
-    aphrodite_extension_sources.append("kernels/quantization/bitsandbytes/gemm_s4_f16.cu")
-    
+    aphrodite_extension_sources.append(
+        "kernels/quantization/bitsandbytes/format.cu")
+    aphrodite_extension_sources.append(
+        "kernels/quantization/bitsandbytes/gemm_s4_f16.cu")
+
     ext_modules.append(
         CUDAExtension(
             name="aphrodite._moe_C",
@@ -333,7 +349,7 @@ if _is_cuda():
                 "nvcc": NVCC_FLAGS,
             },
         ))
-    
+
 aphrodite_extension = CUDAExtension(
     name="aphrodite._C",
     sources=aphrodite_extension_sources,
@@ -341,13 +357,18 @@ aphrodite_extension = CUDAExtension(
         "cxx": CXX_FLAGS,
         "nvcc": NVCC_FLAGS,
     },
-    libraries=["cuda", "conda/envs/aphrodite-runtime/lib",
-               "conda/envs/aphrodite-runtime/lib/stubs"] if _is_cuda() else [],
-    library_dirs=["conda/envs/aphrodite-runtime/lib",
-                  "conda/envs/aphrodite-runtime/lib/stubs"] if _is_cuda() else [],
+    libraries=[
+        "cuda", "conda/envs/aphrodite-runtime/lib",
+        "conda/envs/aphrodite-runtime/lib/stubs"
+    ] if _is_cuda() else [],
+    library_dirs=[
+        "conda/envs/aphrodite-runtime/lib",
+        "conda/envs/aphrodite-runtime/lib/stubs"
+    ] if _is_cuda() else [],
 )
 ext_modules.append(aphrodite_extension)
 
+
 def get_path(*filepath) -> str:
     return os.path.join(ROOT_DIR, *filepath)
 
@@ -367,7 +388,7 @@ def find_version(filepath: str) -> str:
 
 def get_aphrodite_version() -> str:
     version = find_version(get_path("aphrodite", "__init__.py"))
-    
+
     if _is_hip():
         # get the HIP version
 
@@ -424,7 +445,7 @@ setuptools.setup(
         "Programming Language :: Python :: 3.9",
         "Programming Language :: Python :: 3.10",
         "Programming Language :: Python :: 3.11",
-        "License :: OSI Approved :: GNU Affero General Public License v3 or later (AGPLv3+)",
+        "License :: OSI Approved :: GNU Affero General Public License v3 or later (AGPLv3+)",  # noqa: E501
         "Topic :: Scientific/Engineering :: Artificial Intelligence",
     ],
     packages=setuptools.find_packages(exclude=("kernels", "examples",
@@ -436,8 +457,7 @@ setuptools.setup(
     package_data={
         "aphrodite": [
             "endpoints/kobold/klite.embd",
-            "modeling/layers/quantization/hadamard.safetensors",
-            "py.typed"
+            "modeling/layers/quantization/hadamard.safetensors", "py.typed"
         ]
     },
     include_package_data=True,

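The compute-capability comment rewrapped above describes a fallback from 8.9 to 8.0 on older CUDA toolkits; a minimal sketch of that behavior under assumed values (not the verbatim setup.py logic):

# Sketch of the fallback described above, assuming CUDA 11.7 and an 8.9 GPU.
from packaging.version import Version

nvcc_cuda_version = Version("11.7")   # assumed for illustration
compute_capabilities = {"8.9"}

if nvcc_cuda_version < Version("11.8") and any(
        cc.startswith("8.9") for cc in compute_capabilities):
    # 8.9 GPUs can run 8.0 binaries, so target 8.0 when CUDA < 11.8.
    compute_capabilities = {
        cc for cc in compute_capabilities if not cc.startswith("8.9")
    } | {"8.0"}
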
+ 0 - 50
tests/async_engine/api_server_async_aphrodite.py

@@ -1,50 +0,0 @@
-"""aphrodite.endpoints.ooba.api_server with some extra logging for testing."""
-import argparse
-from typing import Any, Dict
-
-import uvicorn
-from fastapi.responses import JSONResponse, Response
-
-import aphrodite.endpoints.ooba.api_server
-from aphrodite.engine.args_tools import AsyncEngineArgs
-from aphrodite.engine.async_aphrodite import AsyncAphrodite
-
-app = aphrodite.endpoints.ooba.api_server.app
-
-
-class AsyncAphroditeWithStats(AsyncAphrodite):
-
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self._num_aborts = 0
-
-    async def abort(self, request_id: str) -> None:
-        await super().abort(request_id)
-        self._num_aborts += 1
-
-    def testing_stats(self) -> Dict[str, Any]:
-        return {"num_aborted_requests": self._num_aborts}
-
-
-@app.get("/stats")
-def stats() -> Response:
-    """Get the statistics of the engine."""
-    return JSONResponse(engine.testing_stats())
-
-
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--host", type=str, default="localhost")
-    parser.add_argument("--port", type=int, default=2242)
-    parser = AsyncEngineArgs.add_cli_args(parser)
-    args = parser.parse_args()
-
-    engine_args = AsyncEngineArgs.from_cli_args(args)
-    engine = AsyncAphroditeWithStats.from_engine_args(engine_args)
-    aphrodite.endpoints.ooba.api_server.engine = engine
-    uvicorn.run(app,
-                host=args.host,
-                port=args.port,
-                log_level="debug",
-                timeout_keep_alive=aphrodite.endpoints.ooba.api_server.
-                TIMEOUT_KEEP_ALIVE)

+ 0 - 93
tests/async_engine/test_api_server.py

@@ -1,93 +0,0 @@
-import subprocess
-import sys
-import time
-from multiprocessing import Pool
-from pathlib import Path
-
-import pytest
-import requests
-
-
-def _query_server(prompt: str, max_tokens: int = 5) -> dict:
-    response = requests.post("http://localhost:2242/generate",
-                             json={
-                                 "prompt": prompt,
-                                 "max_tokens": max_tokens,
-                                 "temperature": 0,
-                                 "ignore_eos": True
-                             })
-    response.raise_for_status()
-    return response.json()
-
-
-def _query_server_long(prompt: str) -> dict:
-    return _query_server(prompt, max_tokens=500)
-
-
-@pytest.fixture
-def api_server():
-    script_path = Path(__file__).parent.joinpath(
-        "api_server_async_engine.py").absolute()
-    uvicorn_process = subprocess.Popen([
-        sys.executable, "-u",
-        str(script_path), "--model", "EleutherAI/pythia-70m-deduped"
-    ])
-    yield
-    uvicorn_process.terminate()
-
-
-def test_api_server(api_server):
-    """
-    Run the API server and test it.
-
-    We run both the server and requests in separate processes.
-
-    We test that the server can handle incoming requests, including
-    multiple requests at the same time, and that it can handle requests
-    being cancelled without crashing.
-    """
-    with Pool(32) as pool:
-        # Wait until the server is ready
-        prompts = ["warm up"] * 1
-        result = None
-        while not result:
-            try:
-                for r in pool.map(_query_server, prompts):
-                    result = r
-                    break
-            except requests.exceptions.ConnectionError:
-                time.sleep(1)
-
-        # Actual tests start here
-        # Try with 1 prompt
-        for result in pool.map(_query_server, prompts):
-            assert result
-
-        num_aborted_requests = requests.get(
-            "http://localhost:8000/stats").json()["num_aborted_requests"]
-        assert num_aborted_requests == 0
-
-        # Try with 100 prompts
-        prompts = ["test prompt"] * 100
-        for result in pool.map(_query_server, prompts):
-            assert result
-
-    with Pool(32) as pool:
-        # Cancel requests
-        prompts = ["canceled requests"] * 100
-        pool.map_async(_query_server_long, prompts)
-        time.sleep(0.01)
-        pool.terminate()
-        pool.join()
-
-        # check cancellation stats
-        num_aborted_requests = requests.get(
-            "http://localhost:8000/stats").json()["num_aborted_requests"]
-        assert num_aborted_requests > 0
-
-    # check that server still runs after cancellations
-    with Pool(32) as pool:
-        # Try with 100 prompts
-        prompts = ["test prompt after canceled"] * 100
-        for result in pool.map(_query_server, prompts):
-            assert result

+ 0 - 91
tests/async_engine/test_async_aphrodite.py

@@ -1,91 +0,0 @@
-import asyncio
-from dataclasses import dataclass
-
-import pytest
-
-from aphrodite.engine.async_aphrodite import AsyncAphrodite
-
-
-@dataclass
-class RequestOutput:
-    request_id: int
-    finished: bool = False
-
-
-class MockEngine:
-
-    def __init__(self):
-        self.step_calls = 0
-        self.add_request_calls = 0
-        self.abort_request_calls = 0
-        self.request_id = None
-
-    async def step_async(self):
-        self.step_calls += 1
-        return [RequestOutput(
-            request_id=self.request_id)] if self.request_id else []
-
-    async def encode_request_async(self, *args, **kwargs):
-        pass
-
-    def generate(self, request_id):
-        self.request_id = request_id
-
-    def stop_generating(self):
-        self.request_id = None
-
-    def add_request(self, **kwargs):
-        del kwargs  # Unused
-        self.add_request_calls += 1
-
-    async def add_request_async(self, **kwargs):
-        self.add_request_calls += 1
-        return
-
-    def abort_request(self, request_id):
-        del request_id  # Unused
-        self.abort_request_calls += 1
-
-    def has_unfinished_requests(self):
-        return self.request_id is not None
-
-
-class MockAsyncAphrodite(AsyncAphrodite):
-
-    def _init_engine(self, *args, **kwargs):
-        return MockEngine()
-
-
-@pytest.mark.asyncio
-async def test_new_requests_event():
-    engine = MockAsyncAphrodite(worker_use_ray=False, engine_use_ray=False)
-    engine.start_background_loop()
-    await asyncio.sleep(0.01)
-    assert engine.engine.step_calls == 0
-
-    await engine.add_request("1", "", None)
-    await asyncio.sleep(0.01)
-    assert engine.engine.add_request_calls == 1
-    assert engine.engine.step_calls == 1
-
-    await engine.add_request("2", "", None)
-    engine.engine.generate("2")
-    await asyncio.sleep(0)
-    await asyncio.sleep(0)
-    assert engine.engine.add_request_calls == 2
-    assert engine.engine.step_calls >= 2
-    await asyncio.sleep(0.001)
-    assert engine.engine.step_calls >= 3
-    engine.engine.stop_generating()
-    await asyncio.sleep(0.001)
-    old_step_calls = engine.engine.step_calls
-    await asyncio.sleep(0.001)
-    assert engine.engine.step_calls == old_step_calls
-
-    await engine.add_request("3", "", None)
-    await asyncio.sleep(0.01)
-    assert engine.engine.add_request_calls == 3
-    assert engine.engine.step_calls == old_step_calls + 1
-    await asyncio.sleep(0.01)
-    assert engine.engine.add_request_calls == 3
-    assert engine.engine.step_calls == old_step_calls + 1

+ 0 - 119
tests/async_engine/test_openai_server.py

@@ -1,119 +0,0 @@
-from argparse import Namespace
-from dataclasses import dataclass
-
-import pytest
-from fastapi.testclient import TestClient
-
-from aphrodite.endpoints.openai.api_server import *
-
-# Define models, templates, and their corresponding expected outputs
-MODEL_TEMPLATE_GENERATON_OUTPUT = [
-    ("EleutherAI/pythia-70m-deduped", None, True,
-     "Hello</s>Hi there!</s>What is the capital of</s>"),
-    ("EleutherAI/pythia-70m-deduped", None, False,
-     "Hello</s>Hi there!</s>What is the capital of</s>"),
-    ("EleutherAI/pythia-70m-deduped", "../../examples/template_chatml.jinja",
-     True, """<|im_start|>user
-Hello<|im_end|>
-<|im_start|>assistant
-Hi there!<|im_end|>
-<|im_start|>user
-What is the capital of<|im_end|>
-<|im_start|>assistant
-"""),
-    ("EleutherAI/pythia-70m-deduped", "../../examples/template_chatml.jinja",
-     False, """<|im_start|>user
-Hello<|im_end|>
-<|im_start|>assistant
-Hi there!<|im_end|>
-<|im_start|>user
-What is the capital of""")
-]
-
-TEST_MESSAGES = [
-    {
-        'role': 'user',
-        'content': 'Hello'
-    },
-    {
-        'role': 'assistant',
-        'content': 'Hi there!'
-    },
-    {
-        'role': 'user',
-        'content': 'What is the capital of'
-    },
-]
-client = TestClient(app)
-
-
-@dataclass
-class MockTokenizer:
-    chat_template = None
-
-
-def test_load_chat_template():
-    # Testing chatml template
-    template = "../../examples/chatml_template.jinja"
-    mock_args = Namespace(chat_template=template)
-    tokenizer = MockTokenizer()
-
-    # Call the function with the mocked args
-    load_chat_template(mock_args, tokenizer)
-
-    template_content = tokenizer.chat_template
-
-    # Test assertions
-    assert template_content is not None
-    # Hard coded value for chatml_template.jinja
-    assert template_content == """{% for message in messages %}{{'<|im_start|>' + message['role'] + '\\n' + message['content']}}{% if (loop.last and add_generation_prompt) or not loop.last %}{{ '<|im_end|>' + '\\n'}}{% endif %}{% endfor %}
-{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{ '<|im_start|>assistant\\n' }}{% endif %}"""
-
-
-def test_no_load_chat_template():
-    # Testing chatml template
-    template = "../../examples/does_not_exist"
-    mock_args = Namespace(chat_template=template)
-    tokenizer = MockTokenizer()
-
-    # Call the function with the mocked args
-    load_chat_template(mock_args, tokenizer=tokenizer)
-    template_content = tokenizer.chat_template
-
-    # Test assertions
-    assert template_content is not None
-    # Hard coded value for chatml_template.jinja
-    assert template_content == """../../examples/does_not_exist"""
-
-
-@pytest.mark.asyncio
-@pytest.mark.parametrize(
-    "model,template,add_generation_prompt,expected_output",
-    MODEL_TEMPLATE_GENERATON_OUTPUT)
-async def test_get_gen_prompt(model, template, add_generation_prompt,
-                              expected_output):
-    # Initialize the tokenizer
-    tokenizer = get_tokenizer(tokenizer_name=model)
-
-    mock_args = Namespace(chat_template=template)
-    load_chat_template(mock_args, tokenizer)
-
-    # Create a mock request object using keyword arguments
-    mock_request = ChatCompletionRequest(
-        model=model,
-        messages=TEST_MESSAGES,
-        add_generation_prompt=add_generation_prompt)
-
-    # Call the function and get the result
-    result = tokenizer.apply_chat_template(
-        conversation=mock_request.messages,
-        tokenize=False,
-        add_generation_prompt=mock_request.add_generation_prompt)
-
-    # Test assertion
-    assert result == expected_output, f"The generated prompt does not match the expected output for model {model} and template {template}"
-
-
-def test_health_endpoint():
-    response = client.get("/health")
-    assert response.status_code == 200

+ 0 - 65
tests/async_engine/test_request_tracker.py

@@ -1,65 +0,0 @@
-import pytest
-
-from aphrodite.engine.async_aphrodite import RequestTracker
-from aphrodite.common.outputs import RequestOutput
-
-
-@pytest.mark.asyncio
-async def test_request_tracker():
-    tracker = RequestTracker()
-    stream_1 = tracker.add_request("1")
-    new, finished = tracker.get_new_and_finished_requests()
-    assert not tracker.new_requests_event.is_set()
-    assert len(new) == 1
-    assert new[0]["request_id"] == "1"
-    assert not finished
-    assert not stream_1.finished
-
-    stream_2 = tracker.add_request("2")
-    stream_3 = tracker.add_request("3")
-    assert tracker.new_requests_event.is_set()
-    await tracker.wait_for_new_requests()
-    new, finished = tracker.get_new_and_finished_requests()
-    assert not tracker.new_requests_event.is_set()
-    assert len(new) == 2
-    assert new[0]["request_id"] == "2"
-    assert new[1]["request_id"] == "3"
-    assert not finished
-    assert not stream_2.finished
-    assert not stream_3.finished
-
-    # request_ids must be unique
-    with pytest.raises(KeyError):
-        tracker.add_request("1")
-    assert not tracker.new_requests_event.is_set()
-
-    tracker.abort_request("1")
-    new, finished = tracker.get_new_and_finished_requests()
-    assert len(finished) == 1
-    assert "1" in finished
-    assert not new
-    assert stream_1.finished
-
-    stream_4 = tracker.add_request("4")
-    tracker.abort_request("4")
-    assert tracker.new_requests_event.is_set()
-    await tracker.wait_for_new_requests()
-    new, finished = tracker.get_new_and_finished_requests()
-    assert len(finished) == 1
-    assert "4" in finished
-    assert not new
-    assert stream_4.finished
-
-    stream_5 = tracker.add_request("5")
-    assert tracker.new_requests_event.is_set()
-    tracker.process_request_output(
-        RequestOutput("2", "output", [], [], [], finished=True))
-    await tracker.wait_for_new_requests()
-    new, finished = tracker.get_new_and_finished_requests()
-    assert not tracker.new_requests_event.is_set()
-    assert len(finished) == 1
-    assert "2" in finished
-    assert len(new) == 1
-    assert new[0]["request_id"] == "5"
-    assert stream_2.finished
-    assert not stream_5.finished

+ 6 - 3
tests/benchmarks/backend_request_func.py

@@ -111,7 +111,8 @@ async def async_request_aphrodite(
                             output.ttft = ttft
                     output.latency = time.perf_counter() - st
 
-                    # When streaming, '\0' is appended to the end of the response.
+                    # When streaming, '\0' is appended to the end of the
+                    # response.
                     body = data.decode("utf-8").strip("\0")
                     output.generated_text = json.loads(
                         body)["text"][0][len(request_func_input.prompt):]
@@ -160,7 +161,8 @@ async def async_request_vllm(
                             output.ttft = ttft
                     output.latency = time.perf_counter() - st
 
-                    # When streaming, '\0' is appended to the end of the response.
+                    # When streaming, '\0' is appended to the end of the
+                    # response.
                     body = data.decode("utf-8").strip("\0")
                     output.generated_text = json.loads(
                         body)["text"][0][len(request_func_input.prompt):]
@@ -242,7 +244,8 @@ async def async_request_deepspeed_mii(
         output = RequestFuncOutput()
         output.prompt_len = request_func_input.prompt_len
 
-        # DeepSpeed-MII doesn't support streaming as of Jan 28 2024, will use 0 as placeholder.
+        # DeepSpeed-MII doesn't support streaming as of Jan 28 2024, will use
+        # 0 as placeholder.
         # https://github.com/microsoft/DeepSpeed-MII/pull/311
         output.ttft = 0
 

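As the rewrapped comments note, the Aphrodite and vLLM backends terminate each streamed chunk with a NUL byte; a small standalone illustration of recovering the JSON body from a made-up chunk:

import json

# A made-up streamed chunk: JSON payload followed by the '\0' terminator.
chunk = b'{"text": ["<prompt><generated text>"]}\x00'
body = chunk.decode("utf-8").strip("\0")
generated = json.loads(body)["text"][0]
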
+ 4 - 3
tests/benchmarks/serving.py

@@ -293,7 +293,8 @@ def main(args: argparse.Namespace):
 
         # Save to file
         base_model_id = model_id.split("/")[-1]
-        file_name = f"{backend}-{args.request_rate}qps-{base_model_id}-{current_dt}.json"
+        file_name = f"{backend}-{args.request_rate}qps-{base_model_id}-"
+        f"{current_dt}.json"
         with open(file_name, "w") as outfile:
             json.dump(result_json, outfile)
 
@@ -340,8 +341,8 @@ if __name__ == "__main__":
     parser.add_argument(
         "--tokenizer",
         type=str,
-        help=
-        "Name or path of the tokenizer, if not using the default model tokenizer.",
+        help="Name or path of the tokenizer, if not using the default model "
+        "tokenizer.",
     )
     parser.add_argument(
         "--best-of",

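The file_name change in the first serving.py hunk is a line-length (E501) fix; wrapping the adjacent f-string fragments in parentheses keeps them a single implicitly concatenated literal, whereas a bare second line would be a separate, discarded expression. A minimal sketch with placeholder values standing in for the benchmark's variables:

backend = "aphrodite"
request_rate = 4.0
base_model_id = "zephyr-7b-beta"
current_dt = "20240101-120000"

# Adjacent string literals inside parentheses are joined at compile time,
# so the split satisfies the line-length rule without changing the result.
file_name = (f"{backend}-{request_rate}qps-{base_model_id}-"
             f"{current_dt}.json")
print(file_name)  # aphrodite-4.0qps-zephyr-7b-beta-20240101-120000.json
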
+ 9 - 7
tests/endpoints/test_openai_server.py

@@ -5,9 +5,9 @@ import time
 import sys
 import pytest
 import requests
-import ray  # using Ray for overall ease of process management, parallel requests, and debugging.
+import ray
 import openai  # use the official client for correctness check
-from huggingface_hub import snapshot_download  # downloading lora to test lora requests
+from huggingface_hub import snapshot_download
 
 # imports for guided decoding tests
 import json
@@ -17,8 +17,8 @@ import re
 from aphrodite.transformers_utils.tokenizer import get_tokenizer
 
 MAX_SERVER_START_WAIT_S = 600  # wait for server to start for 600 seconds
-MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"  # any model with a chat template should work here
-LORA_NAME = "typeof/zephyr-7b-beta-lora"  # technically this needs Mistral-7B-v0.1 as base, but we're not testing generation quality here
+MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
+LORA_NAME = "typeof/zephyr-7b-beta-lora"
 
 TEST_SCHEMA = {
     "type": "object",
@@ -121,7 +121,7 @@ def server(zephyr_lora_files):
         "--model",
         MODEL_NAME,
         "--dtype",
-        "bfloat16",  # use half precision for speed and memory savings in CI environment
+        "bfloat16",  # use half precision for speed and memory savings in CI env
         "--max-model-len",
         "8192",
         "--enforce-eager",
@@ -337,7 +337,8 @@ async def test_batch_completions(server, client: openai.AsyncOpenAI,
         max_tokens=5,
         temperature=0.0,
         extra_body=dict(
-            # NOTE: this has to be true for n > 1 in Aphrodite, but not necessary for official client.
+            # NOTE: this has to be true for n > 1 in Aphrodite, but not
+            # necessary for official client.
             use_beam_search=True),
     )
     assert len(batch.choices) == 4
@@ -415,7 +416,8 @@ async def test_guided_json_completion(server, client: openai.AsyncOpenAI):
     completion = await client.completions.create(
         model=MODEL_NAME,
         prompt=
-        f"Give an example JSON for an employee profile that fits this schema: {TEST_SCHEMA}",
+        "Give an example JSON for an employee profile that fits this schema:"
+        f" {TEST_SCHEMA}",
         n=3,
         temperature=1.0,
         max_tokens=500,

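The use_beam_search note above is Aphrodite-specific behaviour that the official OpenAI client only sees through extra_body. A hedged sketch of such a batch request (base URL, port, and API key are placeholders; an Aphrodite OpenAI-compatible server is assumed to be running):

import asyncio
import openai

async def main() -> None:
    client = openai.AsyncOpenAI(base_url="http://localhost:2242/v1",
                                api_key="EMPTY")  # placeholder credentials
    batch = await client.completions.create(
        model="HuggingFaceH4/zephyr-7b-beta",
        prompt=["Hello, my name is", "Hello, my name is"],
        n=2,
        max_tokens=5,
        temperature=0.0,
        # Aphrodite requires beam search for n > 1 here; the official API
        # does not, which is why the flag rides in extra_body.
        extra_body=dict(use_beam_search=True),
    )
    for choice in batch.choices:
        print(choice.text)

asyncio.run(main())
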
+ 1 - 2
tests/engine/test_detokenize.py

@@ -5,9 +5,8 @@ from transformers import AutoTokenizer
 from aphrodite.transformers_utils.tokenizer import detokenize_incrementally
 
 TRUTH = [
-    # pylint: disable=line-too-long
     "Tell me your favorite story.",
-    "Transformers have revolutionized almost all natural language processing (NLP) tasks but suffer from memory and computational complexity that scales quadratically with sequence length. In contrast, recurrent neural networks (RNNs) exhibit linear scaling in memory and computational requirements but struggle to match the same performance as Transformers due to limitations in parallelization and scalability. We propose a novel model architecture, Receptance Weighted Key Value (RWKV), that combines the efficient parallelizable training of Transformers with the efficient inference of RNNs. Our approach leverages a linear attention mechanism and allows us to formulate the model as either a Transformer or an RNN, which parallelizes computations during training and maintains constant computational and memory complexity during inference, leading to the first non-transformer architecture to be scaled to tens of billions of parameters. Our experiments reveal that RWKV performs on par with similarly sized Transformers, suggesting that future work can leverage this architecture to create more efficient models. This work presents a significant step towards reconciling the trade-offs between computational efficiency and model performance in sequence processing tasks."
+    "Transformers have revolutionized almost all natural language processing (NLP) tasks but suffer from memory and computational complexity that scales quadratically with sequence length. In contrast, recurrent neural networks (RNNs) exhibit linear scaling in memory and computational requirements but struggle to match the same performance as Transformers due to limitations in parallelization and scalability. We propose a novel model architecture, Receptance Weighted Key Value (RWKV), that combines the efficient parallelizable training of Transformers with the efficient inference of RNNs. Our approach leverages a linear attention mechanism and allows us to formulate the model as either a Transformer or an RNN, which parallelizes computations during training and maintains constant computational and memory complexity during inference, leading to the first non-transformer architecture to be scaled to tens of billions of parameters. Our experiments reveal that RWKV performs on par with similarly sized Transformers, suggesting that future work can leverage this architecture to create more efficient models. This work presents a significant step towards reconciling the trade-offs between computational efficiency and model performance in sequence processing tasks."  # noqa: E501
     "トランスフォーマーは、ほぼすべての自然言語処理に革命をもたらしました",
 ]
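
This last hunk is representative of how per-line suppressions are migrated throughout the PR: pylint's disable comments are removed and ruff's flake8-style noqa markers take their place on the offending line. A small illustration with placeholder literals:

# Before: pylint needed an explicit disable for the long test string.
# pylint: disable=line-too-long
OLD_STYLE = "a deliberately long reference sentence used only as test data ..."

# After: ruff suppresses its line-length check (E501) inline instead.
NEW_STYLE = "a deliberately long reference sentence used only as test data ..."  # noqa: E501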