// Package securecookie encodes and decodes authenticated and optionally
// encrypted cookie values.
- package securecookie
- import (
- "bytes"
- "crypto/aes"
- "crypto/cipher"
- "crypto/hmac"
- "crypto/rand"
- "crypto/sha256"
- "crypto/subtle"
- "encoding/base64"
- "encoding/gob"
- "encoding/json"
- "fmt"
- "hash"
- "io"
- "strconv"
- "strings"
- "time"
- )
// Error is the interface implemented by the errors this package produces.
//
// On top of the standard error behaviour it lets a caller classify a failure
// and reach the underlying cause. The classes are flags, not exclusive
// categories: an aggregate error may report more than one of them.
type Error interface {
	error

	// IsUsage reports whether the error is flagged as a usage error, for
	// example a missing key, a value that cannot be serialized, or an
	// encoded value that exceeds the configured maximum length.
	IsUsage() bool

	// IsDecode reports whether the error is flagged as a decode error, for
	// example an invalid MAC, a timestamp outside the accepted range, or a
	// payload that cannot be decrypted or deserialized.
	IsDecode() bool

	// IsInternal reports whether the error is flagged as an internal
	// failure, such as being unable to obtain random bytes for an IV.
	IsInternal() bool

	// Cause returns the underlying error that triggered this one, or nil
	// when there is none.
	Cause() error
}
// errorType is a bit set describing which classes an error belongs to.
type errorType int

// Each class occupies its own bit so that classes can be tested with a mask.
const (
	usageError errorType = 1 << iota
	decodeError
	internalError
)

// cookieError is the concrete error value used throughout the package. It is
// a comparable value type so that sentinel errors can be compared with ==.
type cookieError struct {
	typ   errorType // class flags
	msg   string    // human readable message; may be empty
	cause error     // underlying error; may be nil
}

// IsUsage reports whether the usage flag is set.
func (e cookieError) IsUsage() bool { return e.typ&usageError == usageError }

// IsDecode reports whether the decode flag is set.
func (e cookieError) IsDecode() bool { return e.typ&decodeError == decodeError }

// IsInternal reports whether the internal flag is set.
func (e cookieError) IsInternal() bool { return e.typ&internalError == internalError }

// Cause returns the wrapped error, or nil when there is none.
func (e cookieError) Cause() error { return e.cause }

// Error renders the message as "securecookie: <msg>", substituting "error"
// for an empty message and appending the cause's text when a cause is set.
func (e cookieError) Error() string {
	var sb strings.Builder
	sb.WriteString("securecookie: ")
	if e.msg != "" {
		sb.WriteString(e.msg)
	} else {
		sb.WriteString("error")
	}
	if cause := e.Cause(); cause != nil {
		sb.WriteString(" - caused by: ")
		sb.WriteString(cause.Error())
	}
	return sb.String()
}
- var (
- errGeneratingIV = cookieError{typ: internalError, msg: "failed to generate random iv"}
- errNoCodecs = cookieError{typ: usageError, msg: "no codecs provided"}
- errHashKeyNotSet = cookieError{typ: usageError, msg: "hash key is not set"}
- errBlockKeyNotSet = cookieError{typ: usageError, msg: "block key is not set"}
- errEncodedValueTooLong = cookieError{typ: usageError, msg: "the value is too long"}
- errValueToDecodeTooLong = cookieError{typ: decodeError, msg: "the value is too long"}
- errTimestampInvalid = cookieError{typ: decodeError, msg: "invalid timestamp"}
- errTimestampTooNew = cookieError{typ: decodeError, msg: "timestamp is too new"}
- errTimestampExpired = cookieError{typ: decodeError, msg: "expired timestamp"}
- errDecryptionFailed = cookieError{typ: decodeError, msg: "the value could not be decrypted"}
- errValueNotByte = cookieError{typ: decodeError, msg: "value not a []byte."}
- errValueNotBytePtr = cookieError{typ: decodeError, msg: "value not a pointer to []byte."}
-
-
-
-
-
- ErrMacInvalid = cookieError{typ: decodeError, msg: "the value is not valid"}
- )
// Codec is the pair of operations needed to turn a value into a cookie
// string and back. *SecureCookie satisfies it; EncodeMulti and DecodeMulti
// accept any implementation.
type Codec interface {
	// Encode converts value into a string suitable for the cookie name.
	Encode(name string, value interface{}) (string, error)

	// Decode reverses Encode, storing the result in dst.
	Decode(name, value string, dst interface{}) error
}
- func New(hashKey, blockKey []byte) *SecureCookie {
- s := &SecureCookie{
- hashKey: hashKey,
- blockKey: blockKey,
- hashFunc: sha256.New,
- maxAge: 86400 * 30,
- maxLength: 4096,
- sz: GobEncoder{},
- }
- if hashKey == nil {
- s.err = errHashKeyNotSet
- }
- if blockKey != nil {
- s.BlockFunc(aes.NewCipher)
- }
- return s
- }
- type SecureCookie struct {
- hashKey []byte
- hashFunc func() hash.Hash
- blockKey []byte
- block cipher.Block
- maxLength int
- maxAge int64
- minAge int64
- err error
- sz Serializer
-
-
- timeFunc func() int64
- }
// Serializer converts cookie values to bytes and back. GobEncoder,
// JSONEncoder and NopEncoder are the implementations in this package.
type Serializer interface {
	// Serialize returns the byte representation of src.
	Serialize(src interface{}) ([]byte, error)

	// Deserialize reconstructs a value from src into dst.
	Deserialize(src []byte, dst interface{}) error
}
// Stateless Serializer implementations.
type (
	// GobEncoder serializes values with encoding/gob.
	GobEncoder struct{}

	// JSONEncoder serializes values with encoding/json.
	JSONEncoder struct{}

	// NopEncoder passes []byte values through without transformation.
	NopEncoder struct{}
)
- func (s *SecureCookie) MaxLength(value int) *SecureCookie {
- s.maxLength = value
- return s
- }
- func (s *SecureCookie) MaxAge(value int) *SecureCookie {
- s.maxAge = int64(value)
- return s
- }
- func (s *SecureCookie) MinAge(value int) *SecureCookie {
- s.minAge = int64(value)
- return s
- }
- func (s *SecureCookie) HashFunc(f func() hash.Hash) *SecureCookie {
- s.hashFunc = f
- return s
- }
- func (s *SecureCookie) BlockFunc(f func([]byte) (cipher.Block, error)) *SecureCookie {
- if s.blockKey == nil {
- s.err = errBlockKeyNotSet
- } else if block, err := f(s.blockKey); err == nil {
- s.block = block
- } else {
- s.err = cookieError{cause: err, typ: usageError}
- }
- return s
- }
- func (s *SecureCookie) SetSerializer(sz Serializer) *SecureCookie {
- s.sz = sz
- return s
- }
- func (s *SecureCookie) Encode(name string, value interface{}) (string, error) {
- if s.err != nil {
- return "", s.err
- }
- if s.hashKey == nil {
- s.err = errHashKeyNotSet
- return "", s.err
- }
- var err error
- var b []byte
-
- if b, err = s.sz.Serialize(value); err != nil {
- return "", cookieError{cause: err, typ: usageError}
- }
-
- if s.block != nil {
- if b, err = encrypt(s.block, b); err != nil {
- return "", cookieError{cause: err, typ: usageError}
- }
- }
- b = encode(b)
-
- b = []byte(fmt.Sprintf("%s|%d|%s|", name, s.timestamp(), b))
- mac := createMac(hmac.New(s.hashFunc, s.hashKey), b[:len(b)-1])
-
- b = append(b, mac...)[len(name)+1:]
-
- b = encode(b)
-
- if s.maxLength != 0 && len(b) > s.maxLength {
- return "", errEncodedValueTooLong
- }
-
- return string(b), nil
- }
- func (s *SecureCookie) Decode(name, value string, dst interface{}) error {
- if s.err != nil {
- return s.err
- }
- if s.hashKey == nil {
- s.err = errHashKeyNotSet
- return s.err
- }
-
- if s.maxLength != 0 && len(value) > s.maxLength {
- return errValueToDecodeTooLong
- }
-
- b, err := decode([]byte(value))
- if err != nil {
- return err
- }
-
- parts := bytes.SplitN(b, []byte("|"), 3)
- if len(parts) != 3 {
- return ErrMacInvalid
- }
- h := hmac.New(s.hashFunc, s.hashKey)
- b = append([]byte(name+"|"), b[:len(b)-len(parts[2])-1]...)
- if err = verifyMac(h, b, parts[2]); err != nil {
- return err
- }
-
- var t1 int64
- if t1, err = strconv.ParseInt(string(parts[0]), 10, 64); err != nil {
- return errTimestampInvalid
- }
- t2 := s.timestamp()
- if s.minAge != 0 && t1 > t2-s.minAge {
- return errTimestampTooNew
- }
- if s.maxAge != 0 && t1 < t2-s.maxAge {
- return errTimestampExpired
- }
-
- b, err = decode(parts[1])
- if err != nil {
- return err
- }
- if s.block != nil {
- if b, err = decrypt(s.block, b); err != nil {
- return err
- }
- }
-
- if err = s.sz.Deserialize(b, dst); err != nil {
- return cookieError{cause: err, typ: decodeError}
- }
-
- return nil
- }
- func (s *SecureCookie) timestamp() int64 {
- if s.timeFunc == nil {
- return time.Now().UTC().Unix()
- }
- return s.timeFunc()
- }
// createMac feeds value into h and returns the resulting digest. The caller
// supplies a hash that is ready to use, typically a freshly created HMAC.
func createMac(h hash.Hash, value []byte) []byte {
	// hash.Hash documents that Write never returns an error.
	_, _ = h.Write(value)
	digest := h.Sum(nil)
	return digest
}
- func verifyMac(h hash.Hash, value []byte, mac []byte) error {
- mac2 := createMac(h, value)
-
-
- if len(mac) == len(mac2) && subtle.ConstantTimeCompare(mac, mac2) == 1 {
- return nil
- }
- return ErrMacInvalid
- }
- func encrypt(block cipher.Block, value []byte) ([]byte, error) {
- iv := GenerateRandomKey(block.BlockSize())
- if iv == nil {
- return nil, errGeneratingIV
- }
-
- stream := cipher.NewCTR(block, iv)
- stream.XORKeyStream(value, value)
-
- return append(iv, value...), nil
- }
- func decrypt(block cipher.Block, value []byte) ([]byte, error) {
- size := block.BlockSize()
- if len(value) > size {
-
- iv := value[:size]
-
- value = value[size:]
-
- stream := cipher.NewCTR(block, iv)
- stream.XORKeyStream(value, value)
- return value, nil
- }
- return nil, errDecryptionFailed
- }
- func (e GobEncoder) Serialize(src interface{}) ([]byte, error) {
- buf := new(bytes.Buffer)
- enc := gob.NewEncoder(buf)
- if err := enc.Encode(src); err != nil {
- return nil, cookieError{cause: err, typ: usageError}
- }
- return buf.Bytes(), nil
- }
- func (e GobEncoder) Deserialize(src []byte, dst interface{}) error {
- dec := gob.NewDecoder(bytes.NewBuffer(src))
- if err := dec.Decode(dst); err != nil {
- return cookieError{cause: err, typ: decodeError}
- }
- return nil
- }
- func (e JSONEncoder) Serialize(src interface{}) ([]byte, error) {
- buf := new(bytes.Buffer)
- enc := json.NewEncoder(buf)
- if err := enc.Encode(src); err != nil {
- return nil, cookieError{cause: err, typ: usageError}
- }
- return buf.Bytes(), nil
- }
- func (e JSONEncoder) Deserialize(src []byte, dst interface{}) error {
- dec := json.NewDecoder(bytes.NewReader(src))
- if err := dec.Decode(dst); err != nil {
- return cookieError{cause: err, typ: decodeError}
- }
- return nil
- }
- func (e NopEncoder) Serialize(src interface{}) ([]byte, error) {
- if b, ok := src.([]byte); ok {
- return b, nil
- }
- return nil, errValueNotByte
- }
- func (e NopEncoder) Deserialize(src []byte, dst interface{}) error {
- if dat, ok := dst.(*[]byte); ok {
- *dat = src
- return nil
- }
- return errValueNotBytePtr
- }
// encode returns value encoded with padded, URL-safe base64.
func encode(value []byte) []byte {
	enc := base64.URLEncoding
	out := make([]byte, enc.EncodedLen(len(value)))
	enc.Encode(out, value)
	return out
}
- func decode(value []byte) ([]byte, error) {
- decoded := make([]byte, base64.URLEncoding.DecodedLen(len(value)))
- b, err := base64.URLEncoding.Decode(decoded, value)
- if err != nil {
- return nil, cookieError{cause: err, typ: decodeError, msg: "base64 decode failed"}
- }
- return decoded[:b], nil
- }
// GenerateRandomKey returns length bytes read from crypto/rand, or nil if
// the random source could not supply them. Callers must check for nil
// before using the result as a key.
func GenerateRandomKey(length int) []byte {
	key := make([]byte, length)
	_, err := io.ReadFull(rand.Reader, key)
	if err != nil {
		return nil
	}
	return key
}
- func CodecsFromPairs(keyPairs ...[]byte) []Codec {
- codecs := make([]Codec, len(keyPairs)/2+len(keyPairs)%2)
- for i := 0; i < len(keyPairs); i += 2 {
- var blockKey []byte
- if i+1 < len(keyPairs) {
- blockKey = keyPairs[i+1]
- }
- codecs[i/2] = New(keyPairs[i], blockKey)
- }
- return codecs
- }
- func EncodeMulti(name string, value interface{}, codecs ...Codec) (string, error) {
- if len(codecs) == 0 {
- return "", errNoCodecs
- }
- var errors MultiError
- for _, codec := range codecs {
- encoded, err := codec.Encode(name, value)
- if err == nil {
- return encoded, nil
- }
- errors = append(errors, err)
- }
- return "", errors
- }
- func DecodeMulti(name string, value string, dst interface{}, codecs ...Codec) error {
- if len(codecs) == 0 {
- return errNoCodecs
- }
- var errors MultiError
- for _, codec := range codecs {
- err := codec.Decode(name, value, dst)
- if err == nil {
- return nil
- }
- errors = append(errors, err)
- }
- return errors
- }
- type MultiError []error
- func (m MultiError) IsUsage() bool { return m.any(func(e Error) bool { return e.IsUsage() }) }
- func (m MultiError) IsDecode() bool { return m.any(func(e Error) bool { return e.IsDecode() }) }
- func (m MultiError) IsInternal() bool { return m.any(func(e Error) bool { return e.IsInternal() }) }
- func (m MultiError) Cause() error { return nil }
- func (m MultiError) Error() string {
- s, n := "", 0
- for _, e := range m {
- if e != nil {
- if n == 0 {
- s = e.Error()
- }
- n++
- }
- }
- switch n {
- case 0:
- return "(0 errors)"
- case 1:
- return s
- case 2:
- return s + " (and 1 other error)"
- }
- return fmt.Sprintf("%s (and %d other errors)", s, n-1)
- }
- func (m MultiError) any(pred func(Error) bool) bool {
- for _, e := range m {
- if ourErr, ok := e.(Error); ok && pred(ourErr) {
- return true
- }
- }
- return false
- }
|