baichuan.py
# coding=utf-8
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from transformers.configuration_utils import PretrainedConfig


class BaiChuanConfig(PretrainedConfig):
    model_type = "baichuan"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=64000,
        hidden_size=4096,
        intermediate_size=11008,
        num_hidden_layers=32,
        num_attention_heads=32,
        hidden_act="silu",
        max_position_embeddings=4096,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        pad_token_id=0,
        bos_token_id=1,
        eos_token_id=2,
        tie_word_embeddings=False,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
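

# --- Usage sketch (illustrative, not part of the original file) ---
# A minimal example of how this configuration class can be exercised through the
# inherited PretrainedConfig helpers. The overridden field values and the output
# directory name below are assumptions chosen for demonstration only.
if __name__ == "__main__":
    # Defaults describe the 7B-scale layout: 32 layers, 32 heads, hidden size 4096.
    config = BaiChuanConfig()
    print(config.hidden_size, config.num_hidden_layers)

    # Override a few fields to describe a small test model, then round-trip it
    # to disk as config.json via save_pretrained / from_pretrained.
    tiny = BaiChuanConfig(hidden_size=256, num_hidden_layers=2, num_attention_heads=4)
    tiny.save_pretrained("./baichuan-tiny")  # hypothetical output path
    reloaded = BaiChuanConfig.from_pretrained("./baichuan-tiny")
    assert reloaded.hidden_size == 256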