|
@@ -367,4 +367,11 @@ def initialize_dummy_weights(
|
|
|
"""
|
|
|
for param in model.state_dict().values():
|
|
|
if torch.is_floating_point(param):
|
|
|
- param.data.uniform_(low, high)
|
|
|
+ if torch.finfo(param.data.dtype).bits < 16:
|
|
|
+ # uniform_ doesn't support dtypes narrower than 16 bits (e.g. FP8)
|
|
|
+ dtype = param.data.dtype
|
|
|
+ tmp_param = param.data.to(torch.float16)
|
|
|
+ tmp_param = tmp_param.uniform_(low, high).to(dtype)
|
|
|
+ param.data.copy_(tmp_param)
|
|
|
+ else:
|
|
|
+ param.uniform_(low, high)
|