packet_handler_map.go 7.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277
  1. package quic
  2. import (
  3. "bytes"
  4. "errors"
  5. "fmt"
  6. "net"
  7. "sync"
  8. "time"
  9. "github.com/lucas-clemente/quic-go/internal/protocol"
  10. "github.com/lucas-clemente/quic-go/internal/utils"
  11. "github.com/lucas-clemente/quic-go/internal/wire"
  12. )
  13. type packetHandlerEntry struct {
  14. handler packetHandler
  15. resetToken *[16]byte
  16. }
  17. // The packetHandlerMap stores packetHandlers, identified by connection ID.
  18. // It is used:
  19. // * by the server to store sessions
  20. // * when multiplexing outgoing connections to store clients
  21. type packetHandlerMap struct {
  22. mutex sync.RWMutex
  23. conn net.PacketConn
  24. connIDLen int
  25. handlers map[string] /* string(ConnectionID)*/ packetHandlerEntry
  26. resetTokens map[[16]byte] /* stateless reset token */ packetHandler
  27. server unknownPacketHandler
  28. listening chan struct{} // is closed when listen returns
  29. closed bool
  30. deleteRetiredSessionsAfter time.Duration
  31. logger utils.Logger
  32. }
  33. var _ packetHandlerManager = &packetHandlerMap{}
  34. func newPacketHandlerMap(conn net.PacketConn, connIDLen int, logger utils.Logger) packetHandlerManager {
  35. m := &packetHandlerMap{
  36. conn: conn,
  37. connIDLen: connIDLen,
  38. listening: make(chan struct{}),
  39. handlers: make(map[string]packetHandlerEntry),
  40. resetTokens: make(map[[16]byte]packetHandler),
  41. deleteRetiredSessionsAfter: protocol.RetiredConnectionIDDeleteTimeout,
  42. logger: logger,
  43. }
  44. go m.listen()
  45. return m
  46. }
  47. func (h *packetHandlerMap) Add(id protocol.ConnectionID, handler packetHandler) {
  48. h.mutex.Lock()
  49. h.handlers[string(id)] = packetHandlerEntry{handler: handler}
  50. h.mutex.Unlock()
  51. }
  52. func (h *packetHandlerMap) AddWithResetToken(id protocol.ConnectionID, handler packetHandler, token [16]byte) {
  53. h.mutex.Lock()
  54. h.handlers[string(id)] = packetHandlerEntry{handler: handler, resetToken: &token}
  55. h.resetTokens[token] = handler
  56. h.mutex.Unlock()
  57. }
  58. func (h *packetHandlerMap) Remove(id protocol.ConnectionID) {
  59. h.removeByConnectionIDAsString(string(id))
  60. }
  61. func (h *packetHandlerMap) removeByConnectionIDAsString(id string) {
  62. h.mutex.Lock()
  63. if handlerEntry, ok := h.handlers[id]; ok {
  64. if token := handlerEntry.resetToken; token != nil {
  65. delete(h.resetTokens, *token)
  66. }
  67. delete(h.handlers, id)
  68. }
  69. h.mutex.Unlock()
  70. }
  71. func (h *packetHandlerMap) Retire(id protocol.ConnectionID) {
  72. h.retireByConnectionIDAsString(string(id))
  73. }
  74. func (h *packetHandlerMap) retireByConnectionIDAsString(id string) {
  75. time.AfterFunc(h.deleteRetiredSessionsAfter, func() {
  76. h.removeByConnectionIDAsString(id)
  77. })
  78. }
  79. func (h *packetHandlerMap) SetServer(s unknownPacketHandler) {
  80. h.mutex.Lock()
  81. h.server = s
  82. h.mutex.Unlock()
  83. }
  84. func (h *packetHandlerMap) CloseServer() {
  85. h.mutex.Lock()
  86. h.server = nil
  87. var wg sync.WaitGroup
  88. for id, handlerEntry := range h.handlers {
  89. handler := handlerEntry.handler
  90. if handler.GetPerspective() == protocol.PerspectiveServer {
  91. wg.Add(1)
  92. go func(id string, handler packetHandler) {
  93. // session.Close() blocks until the CONNECTION_CLOSE has been sent and the run-loop has stopped
  94. _ = handler.Close()
  95. h.retireByConnectionIDAsString(id)
  96. wg.Done()
  97. }(id, handler)
  98. }
  99. }
  100. h.mutex.Unlock()
  101. wg.Wait()
  102. }
  103. // Close the underlying connection and wait until listen() has returned.
  104. func (h *packetHandlerMap) Close() error {
  105. if err := h.conn.Close(); err != nil {
  106. return err
  107. }
  108. <-h.listening // wait until listening returns
  109. return nil
  110. }
  111. func (h *packetHandlerMap) close(e error) error {
  112. h.mutex.Lock()
  113. if h.closed {
  114. h.mutex.Unlock()
  115. return nil
  116. }
  117. h.closed = true
  118. var wg sync.WaitGroup
  119. for _, handlerEntry := range h.handlers {
  120. wg.Add(1)
  121. go func(handlerEntry packetHandlerEntry) {
  122. handlerEntry.handler.destroy(e)
  123. wg.Done()
  124. }(handlerEntry)
  125. }
  126. if h.server != nil {
  127. h.server.closeWithError(e)
  128. }
  129. h.mutex.Unlock()
  130. wg.Wait()
  131. return getMultiplexer().RemoveConn(h.conn)
  132. }
  133. func (h *packetHandlerMap) listen() {
  134. defer close(h.listening)
  135. for {
  136. buffer := getPacketBuffer()
  137. data := buffer.Slice
  138. // The packet size should not exceed protocol.MaxReceivePacketSize bytes
  139. // If it does, we only read a truncated packet, which will then end up undecryptable
  140. n, addr, err := h.conn.ReadFrom(data)
  141. if err != nil {
  142. h.close(err)
  143. return
  144. }
  145. h.handlePacket(addr, buffer, data[:n])
  146. }
  147. }
  148. func (h *packetHandlerMap) handlePacket(
  149. addr net.Addr,
  150. buffer *packetBuffer,
  151. data []byte,
  152. ) {
  153. packets, err := h.parsePacket(addr, buffer, data)
  154. if err != nil {
  155. h.logger.Debugf("error parsing packets from %s: %s", addr, err)
  156. // This is just the error from parsing the last packet.
  157. // We still need to process the packets that were successfully parsed before.
  158. }
  159. if len(packets) == 0 {
  160. buffer.Release()
  161. return
  162. }
  163. h.handleParsedPackets(packets)
  164. }
  165. func (h *packetHandlerMap) parsePacket(
  166. addr net.Addr,
  167. buffer *packetBuffer,
  168. data []byte,
  169. ) ([]*receivedPacket, error) {
  170. rcvTime := time.Now()
  171. packets := make([]*receivedPacket, 0, 1)
  172. var counter int
  173. var lastConnID protocol.ConnectionID
  174. for len(data) > 0 {
  175. hdr, err := wire.ParseHeader(bytes.NewReader(data), h.connIDLen)
  176. // drop the packet if we can't parse the header
  177. if err != nil {
  178. return packets, fmt.Errorf("error parsing header: %s", err)
  179. }
  180. if counter > 0 && !hdr.DestConnectionID.Equal(lastConnID) {
  181. return packets, fmt.Errorf("coalesced packet has different destination connection ID: %s, expected %s", hdr.DestConnectionID, lastConnID)
  182. }
  183. lastConnID = hdr.DestConnectionID
  184. var rest []byte
  185. if hdr.IsLongHeader {
  186. if protocol.ByteCount(len(data)) < hdr.ParsedLen()+hdr.Length {
  187. return packets, fmt.Errorf("packet length (%d bytes) is smaller than the expected length (%d bytes)", len(data)-int(hdr.ParsedLen()), hdr.Length)
  188. }
  189. packetLen := int(hdr.ParsedLen() + hdr.Length)
  190. rest = data[packetLen:]
  191. data = data[:packetLen]
  192. }
  193. if counter > 0 {
  194. buffer.Split()
  195. }
  196. counter++
  197. packets = append(packets, &receivedPacket{
  198. remoteAddr: addr,
  199. hdr: hdr,
  200. rcvTime: rcvTime,
  201. data: data,
  202. buffer: buffer,
  203. })
  204. // only log if this actually a coalesced packet
  205. if h.logger.Debug() && (counter > 1 || len(rest) > 0) {
  206. h.logger.Debugf("Parsed a coalesced packet. Part %d: %d bytes. Remaining: %d bytes.", counter, len(packets[counter-1].data), len(rest))
  207. }
  208. data = rest
  209. }
  210. return packets, nil
  211. }
  212. func (h *packetHandlerMap) handleParsedPackets(packets []*receivedPacket) {
  213. h.mutex.RLock()
  214. defer h.mutex.RUnlock()
  215. // coalesced packets all have the same destination connection ID
  216. handlerEntry, handlerFound := h.handlers[string(packets[0].hdr.DestConnectionID)]
  217. for _, p := range packets {
  218. if handlerFound { // existing session
  219. handlerEntry.handler.handlePacket(p)
  220. continue
  221. }
  222. // No session found.
  223. // This might be a stateless reset.
  224. if !p.hdr.IsLongHeader {
  225. if len(p.data) >= protocol.MinStatelessResetSize {
  226. var token [16]byte
  227. copy(token[:], p.data[len(p.data)-16:])
  228. if sess, ok := h.resetTokens[token]; ok {
  229. sess.destroy(errors.New("received a stateless reset"))
  230. continue
  231. }
  232. }
  233. // TODO(#943): send a stateless reset
  234. h.logger.Debugf("received a short header packet with an unexpected connection ID %s", p.hdr.DestConnectionID)
  235. break // a short header packet is always the last in a coalesced packet
  236. }
  237. if h.server == nil { // no server set
  238. h.logger.Debugf("received a packet with an unexpected connection ID %s", p.hdr.DestConnectionID)
  239. continue
  240. }
  241. h.server.handlePacket(p)
  242. }
  243. }