|
@@ -1,13 +1,14 @@
|
|
|
import contextlib
|
|
|
+import functools
|
|
|
from typing import List, Optional, Tuple, Type
|
|
|
|
|
|
import torch
|
|
|
+from loguru import logger
|
|
|
|
|
|
try:
|
|
|
import aphrodite._C
|
|
|
except ImportError as e:
|
|
|
- from loguru import logger
|
|
|
- logger.warning("Failed to import from vllm._C with %r", e)
|
|
|
+ logger.warning(f"Failed to import from aphrodite._C with {e}")
|
|
|
|
|
|
with contextlib.suppress(ImportError):
|
|
|
import aphrodite._moe_C
|
|
@@ -22,6 +23,25 @@ def is_custom_op_supported(op_name: str) -> bool:
|
|
|
return op is not None
|
|
|
|
|
|
|
|
|
+def hint_on_error(fn):
|
|
|
+
|
|
|
+ @functools.wraps(fn)
|
|
|
+ def wrapper(*args, **kwargs):
|
|
|
+ try:
|
|
|
+ return fn(*args, **kwargs)
|
|
|
+ except AttributeError as e:
|
|
|
+ msg = (
|
|
|
+ f"Error in calling custom op {fn.__name__}: {e}\n"
|
|
|
+ f"Possibly you have built or installed an obsolete version of aphrodite.\n"
|
|
|
+ f"Please try a clean build and install of aphrodite,"
|
|
|
+ f"or remove old built files such as aphrodite/*.so and build/ ."
|
|
|
+ )
|
|
|
+ logger.error(msg)
|
|
|
+ raise e
|
|
|
+
|
|
|
+ return wrapper
|
|
|
+
|
|
|
+
|
|
|
# activation ops
|
|
|
def silu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
|
|
|
torch.ops._C.silu_and_mul(out, x)
|
|
@@ -475,3 +495,24 @@ def dispatch_bgmv_low_level(
|
|
|
h_out,
|
|
|
y_offset,
|
|
|
)
|
|
|
+
|
|
|
+
|
|
|
+# TODO: remove this later
|
|
|
+names_and_values = globals()
|
|
|
+names_and_values_to_update = {}
|
|
|
+# prepare variables to avoid dict size change during iteration
|
|
|
+k, v, arg = None, None, None
|
|
|
+fn_type = type(lambda x: x)
|
|
|
+for k, v in names_and_values.items():
|
|
|
+ # find functions that are defined in this file and have torch.Tensor
|
|
|
+ # in their annotations. `arg == "torch.Tensor"` is used to handle
|
|
|
+ # the case when users use `import __annotations__` to turn type
|
|
|
+ # hints into strings.
|
|
|
+ if isinstance(v, fn_type) \
|
|
|
+ and v.__code__.co_filename == __file__ \
|
|
|
+ and any(arg is torch.Tensor or arg == "torch.Tensor"
|
|
|
+ for arg in v.__annotations__.values()):
|
|
|
+ names_and_values_to_update[k] = hint_on_error(v)
|
|
|
+
|
|
|
+names_and_values.update(names_and_values_to_update)
|
|
|
+del names_and_values_to_update, names_and_values, v, k, fn_type
|