header.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333
  1. package wire
  2. import (
  3. "bytes"
  4. "crypto/rand"
  5. "errors"
  6. "fmt"
  7. "github.com/lucas-clemente/quic-go/internal/protocol"
  8. "github.com/lucas-clemente/quic-go/internal/utils"
  9. )
  10. // Header is the header of a QUIC packet.
  11. // It contains fields that are only needed for the gQUIC Public Header and the IETF draft Header.
  12. type Header struct {
  13. IsPublicHeader bool
  14. Raw []byte
  15. Version protocol.VersionNumber
  16. DestConnectionID protocol.ConnectionID
  17. SrcConnectionID protocol.ConnectionID
  18. OrigDestConnectionID protocol.ConnectionID // only needed in the Retry packet
  19. PacketNumberLen protocol.PacketNumberLen
  20. PacketNumber protocol.PacketNumber
  21. IsVersionNegotiation bool
  22. SupportedVersions []protocol.VersionNumber // Version Number sent in a Version Negotiation Packet by the server
  23. // only needed for the gQUIC Public Header
  24. VersionFlag bool
  25. ResetFlag bool
  26. DiversificationNonce []byte
  27. // only needed for the IETF Header
  28. Type protocol.PacketType
  29. IsLongHeader bool
  30. KeyPhase int
  31. PayloadLen protocol.ByteCount
  32. Token []byte
  33. }
  34. var errInvalidPacketNumberLen = errors.New("invalid packet number length")
  35. // Write writes the Header.
  36. func (h *Header) Write(b *bytes.Buffer, pers protocol.Perspective, ver protocol.VersionNumber) error {
  37. if !ver.UsesIETFHeaderFormat() {
  38. h.IsPublicHeader = true // save that this is a Public Header, so we can log it correctly later
  39. return h.writePublicHeader(b, pers, ver)
  40. }
  41. // write an IETF QUIC header
  42. if h.IsLongHeader {
  43. return h.writeLongHeader(b, ver)
  44. }
  45. return h.writeShortHeader(b, ver)
  46. }
  47. // TODO: add support for the key phase
  48. func (h *Header) writeLongHeader(b *bytes.Buffer, v protocol.VersionNumber) error {
  49. b.WriteByte(byte(0x80 | h.Type))
  50. utils.BigEndian.WriteUint32(b, uint32(h.Version))
  51. connIDLen, err := encodeConnIDLen(h.DestConnectionID, h.SrcConnectionID)
  52. if err != nil {
  53. return err
  54. }
  55. b.WriteByte(connIDLen)
  56. b.Write(h.DestConnectionID.Bytes())
  57. b.Write(h.SrcConnectionID.Bytes())
  58. if h.Type == protocol.PacketTypeInitial && v.UsesTokenInHeader() {
  59. utils.WriteVarInt(b, uint64(len(h.Token)))
  60. b.Write(h.Token)
  61. }
  62. if h.Type == protocol.PacketTypeRetry {
  63. odcil, err := encodeSingleConnIDLen(h.OrigDestConnectionID)
  64. if err != nil {
  65. return err
  66. }
  67. // randomize the first 4 bits
  68. odcilByte := make([]byte, 1)
  69. _, _ = rand.Read(odcilByte) // it's safe to ignore the error here
  70. odcilByte[0] = (odcilByte[0] & 0xf0) | odcil
  71. b.Write(odcilByte)
  72. b.Write(h.OrigDestConnectionID.Bytes())
  73. b.Write(h.Token)
  74. return nil
  75. }
  76. if v.UsesLengthInHeader() {
  77. utils.WriteVarInt(b, uint64(h.PayloadLen))
  78. }
  79. if v.UsesVarintPacketNumbers() {
  80. return utils.WriteVarIntPacketNumber(b, h.PacketNumber, h.PacketNumberLen)
  81. }
  82. utils.BigEndian.WriteUint32(b, uint32(h.PacketNumber))
  83. if h.Type == protocol.PacketType0RTT && v == protocol.Version44 {
  84. if len(h.DiversificationNonce) != 32 {
  85. return errors.New("invalid diversification nonce length")
  86. }
  87. b.Write(h.DiversificationNonce)
  88. }
  89. return nil
  90. }
  91. func (h *Header) writeShortHeader(b *bytes.Buffer, v protocol.VersionNumber) error {
  92. typeByte := byte(0x30)
  93. typeByte |= byte(h.KeyPhase << 6)
  94. if !v.UsesVarintPacketNumbers() {
  95. switch h.PacketNumberLen {
  96. case protocol.PacketNumberLen1:
  97. case protocol.PacketNumberLen2:
  98. typeByte |= 0x1
  99. case protocol.PacketNumberLen4:
  100. typeByte |= 0x2
  101. default:
  102. return errInvalidPacketNumberLen
  103. }
  104. }
  105. b.WriteByte(typeByte)
  106. b.Write(h.DestConnectionID.Bytes())
  107. if !v.UsesVarintPacketNumbers() {
  108. switch h.PacketNumberLen {
  109. case protocol.PacketNumberLen1:
  110. b.WriteByte(uint8(h.PacketNumber))
  111. case protocol.PacketNumberLen2:
  112. utils.BigEndian.WriteUint16(b, uint16(h.PacketNumber))
  113. case protocol.PacketNumberLen4:
  114. utils.BigEndian.WriteUint32(b, uint32(h.PacketNumber))
  115. }
  116. return nil
  117. }
  118. return utils.WriteVarIntPacketNumber(b, h.PacketNumber, h.PacketNumberLen)
  119. }
  120. // writePublicHeader writes a Public Header.
  121. func (h *Header) writePublicHeader(b *bytes.Buffer, pers protocol.Perspective, _ protocol.VersionNumber) error {
  122. if h.ResetFlag || (h.VersionFlag && pers == protocol.PerspectiveServer) {
  123. return errors.New("PublicHeader: Can only write regular packets")
  124. }
  125. if h.SrcConnectionID.Len() != 0 {
  126. return errors.New("PublicHeader: SrcConnectionID must not be set")
  127. }
  128. if len(h.DestConnectionID) != 0 && len(h.DestConnectionID) != 8 {
  129. return fmt.Errorf("PublicHeader: wrong length for Connection ID: %d (expected 8)", len(h.DestConnectionID))
  130. }
  131. publicFlagByte := uint8(0x00)
  132. if h.VersionFlag {
  133. publicFlagByte |= 0x01
  134. }
  135. if h.DestConnectionID.Len() > 0 {
  136. publicFlagByte |= 0x08
  137. }
  138. if len(h.DiversificationNonce) > 0 {
  139. if len(h.DiversificationNonce) != 32 {
  140. return errors.New("invalid diversification nonce length")
  141. }
  142. publicFlagByte |= 0x04
  143. }
  144. switch h.PacketNumberLen {
  145. case protocol.PacketNumberLen1:
  146. publicFlagByte |= 0x00
  147. case protocol.PacketNumberLen2:
  148. publicFlagByte |= 0x10
  149. case protocol.PacketNumberLen4:
  150. publicFlagByte |= 0x20
  151. }
  152. b.WriteByte(publicFlagByte)
  153. if h.DestConnectionID.Len() > 0 {
  154. b.Write(h.DestConnectionID)
  155. }
  156. if h.VersionFlag && pers == protocol.PerspectiveClient {
  157. utils.BigEndian.WriteUint32(b, uint32(h.Version))
  158. }
  159. if len(h.DiversificationNonce) > 0 {
  160. b.Write(h.DiversificationNonce)
  161. }
  162. switch h.PacketNumberLen {
  163. case protocol.PacketNumberLen1:
  164. b.WriteByte(uint8(h.PacketNumber))
  165. case protocol.PacketNumberLen2:
  166. utils.BigEndian.WriteUint16(b, uint16(h.PacketNumber))
  167. case protocol.PacketNumberLen4:
  168. utils.BigEndian.WriteUint32(b, uint32(h.PacketNumber))
  169. case protocol.PacketNumberLen6:
  170. return errInvalidPacketNumberLen
  171. default:
  172. return errors.New("PublicHeader: PacketNumberLen not set")
  173. }
  174. return nil
  175. }
  176. // GetLength determines the length of the Header.
  177. func (h *Header) GetLength(v protocol.VersionNumber) (protocol.ByteCount, error) {
  178. if !v.UsesIETFHeaderFormat() {
  179. return h.getPublicHeaderLength()
  180. }
  181. return h.getHeaderLength(v)
  182. }
  183. func (h *Header) getHeaderLength(v protocol.VersionNumber) (protocol.ByteCount, error) {
  184. if h.IsLongHeader {
  185. length := 1 /* type byte */ + 4 /* version */ + 1 /* conn id len byte */ + protocol.ByteCount(h.DestConnectionID.Len()+h.SrcConnectionID.Len()) + protocol.ByteCount(h.PacketNumberLen)
  186. if v.UsesLengthInHeader() {
  187. length += utils.VarIntLen(uint64(h.PayloadLen))
  188. }
  189. if h.Type == protocol.PacketTypeInitial && v.UsesTokenInHeader() {
  190. length += utils.VarIntLen(uint64(len(h.Token))) + protocol.ByteCount(len(h.Token))
  191. }
  192. if h.Type == protocol.PacketType0RTT && v == protocol.Version44 {
  193. length += protocol.ByteCount(len(h.DiversificationNonce))
  194. }
  195. return length, nil
  196. }
  197. length := protocol.ByteCount(1 /* type byte */ + h.DestConnectionID.Len())
  198. if h.PacketNumberLen != protocol.PacketNumberLen1 && h.PacketNumberLen != protocol.PacketNumberLen2 && h.PacketNumberLen != protocol.PacketNumberLen4 {
  199. return 0, fmt.Errorf("invalid packet number length: %d", h.PacketNumberLen)
  200. }
  201. length += protocol.ByteCount(h.PacketNumberLen)
  202. return length, nil
  203. }
  204. // getPublicHeaderLength gets the length of the publicHeader in bytes.
  205. // It can only be called for regular packets.
  206. func (h *Header) getPublicHeaderLength() (protocol.ByteCount, error) {
  207. length := protocol.ByteCount(1) // 1 byte for public flags
  208. if h.PacketNumberLen == protocol.PacketNumberLen6 {
  209. return 0, errInvalidPacketNumberLen
  210. }
  211. if h.PacketNumberLen != protocol.PacketNumberLen1 && h.PacketNumberLen != protocol.PacketNumberLen2 && h.PacketNumberLen != protocol.PacketNumberLen4 {
  212. return 0, errPacketNumberLenNotSet
  213. }
  214. length += protocol.ByteCount(h.PacketNumberLen)
  215. length += protocol.ByteCount(h.DestConnectionID.Len())
  216. // Version Number in packets sent by the client
  217. if h.VersionFlag {
  218. length += 4
  219. }
  220. length += protocol.ByteCount(len(h.DiversificationNonce))
  221. return length, nil
  222. }
  223. // Log logs the Header
  224. func (h *Header) Log(logger utils.Logger) {
  225. if h.IsPublicHeader {
  226. h.logPublicHeader(logger)
  227. } else {
  228. h.logHeader(logger)
  229. }
  230. }
  231. func (h *Header) logHeader(logger utils.Logger) {
  232. if h.IsLongHeader {
  233. if h.Version == 0 {
  234. logger.Debugf("\tVersionNegotiationPacket{DestConnectionID: %s, SrcConnectionID: %s, SupportedVersions: %s}", h.DestConnectionID, h.SrcConnectionID, h.SupportedVersions)
  235. } else {
  236. var token string
  237. if h.Type == protocol.PacketTypeInitial || h.Type == protocol.PacketTypeRetry {
  238. if len(h.Token) == 0 {
  239. token = "Token: (empty), "
  240. } else {
  241. token = fmt.Sprintf("Token: %#x, ", h.Token)
  242. }
  243. }
  244. if h.Type == protocol.PacketTypeRetry {
  245. logger.Debugf("\tLong Header{Type: %s, DestConnectionID: %s, SrcConnectionID: %s, %sOrigDestConnectionID: %s, Version: %s}", h.Type, h.DestConnectionID, h.SrcConnectionID, token, h.OrigDestConnectionID, h.Version)
  246. return
  247. }
  248. if h.Version == protocol.Version44 {
  249. var divNonce string
  250. if h.Type == protocol.PacketType0RTT {
  251. divNonce = fmt.Sprintf("Diversification Nonce: %#x, ", h.DiversificationNonce)
  252. }
  253. logger.Debugf("\tLong Header{Type: %s, DestConnectionID: %s, SrcConnectionID: %s, PacketNumber: %#x, PacketNumberLen: %d, %sVersion: %s}", h.Type, h.DestConnectionID, h.SrcConnectionID, h.PacketNumber, h.PacketNumberLen, divNonce, h.Version)
  254. return
  255. }
  256. logger.Debugf("\tLong Header{Type: %s, DestConnectionID: %s, SrcConnectionID: %s, %sPacketNumber: %#x, PacketNumberLen: %d, PayloadLen: %d, Version: %s}", h.Type, h.DestConnectionID, h.SrcConnectionID, token, h.PacketNumber, h.PacketNumberLen, h.PayloadLen, h.Version)
  257. }
  258. } else {
  259. logger.Debugf("\tShort Header{DestConnectionID: %s, PacketNumber: %#x, PacketNumberLen: %d, KeyPhase: %d}", h.DestConnectionID, h.PacketNumber, h.PacketNumberLen, h.KeyPhase)
  260. }
  261. }
  262. func (h *Header) logPublicHeader(logger utils.Logger) {
  263. ver := "(unset)"
  264. if h.Version != 0 {
  265. ver = h.Version.String()
  266. }
  267. logger.Debugf("\tPublic Header{ConnectionID: %s, PacketNumber: %#x, PacketNumberLen: %d, Version: %s, DiversificationNonce: %#v}", h.DestConnectionID, h.PacketNumber, h.PacketNumberLen, ver, h.DiversificationNonce)
  268. }
  269. func encodeConnIDLen(dest, src protocol.ConnectionID) (byte, error) {
  270. dcil, err := encodeSingleConnIDLen(dest)
  271. if err != nil {
  272. return 0, err
  273. }
  274. scil, err := encodeSingleConnIDLen(src)
  275. if err != nil {
  276. return 0, err
  277. }
  278. return scil | dcil<<4, nil
  279. }
  280. func encodeSingleConnIDLen(id protocol.ConnectionID) (byte, error) {
  281. len := id.Len()
  282. if len == 0 {
  283. return 0, nil
  284. }
  285. if len < 4 || len > 18 {
  286. return 0, fmt.Errorf("invalid connection ID length: %d bytes", len)
  287. }
  288. return byte(len - 3), nil
  289. }
  290. func decodeConnIDLen(enc byte) (int /*dest conn id len*/, int /*src conn id len*/) {
  291. return decodeSingleConnIDLen(enc >> 4), decodeSingleConnIDLen(enc & 0xf)
  292. }
  293. func decodeSingleConnIDLen(enc uint8) int {
  294. if enc == 0 {
  295. return 0
  296. }
  297. return int(enc) + 3
  298. }