packet_handler_map.go 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198
  1. package quic
  2. import (
  3. "bytes"
  4. "fmt"
  5. "net"
  6. "sync"
  7. "time"
  8. "github.com/lucas-clemente/quic-go/internal/protocol"
  9. "github.com/lucas-clemente/quic-go/internal/utils"
  10. "github.com/lucas-clemente/quic-go/internal/wire"
  11. )
// The packetHandlerMap stores packetHandlers, identified by connection ID.
// It is used:
// * by the server to store sessions
// * when multiplexing outgoing connections to store clients
//
// Packets read from conn are dispatched by their destination connection ID;
// packets for IDs that are not in the map are passed to server, if one is set.
type packetHandlerMap struct {
	// mutex protects handlers, server and closed.
	mutex sync.RWMutex

	// conn is the connection that listen reads packets from.
	conn net.PacketConn
	// connIDLen is passed to wire.ParseInvariantHeader when parsing the
	// header of incoming packets.
	connIDLen int

	// handlers is keyed by string(ConnectionID). A nil value marks a removed
	// connection ID: packets addressed to it are dropped until the entry is
	// deleted after deleteClosedSessionsAfter.
	handlers map[string]packetHandler
	// server receives packets whose destination connection ID is not in handlers.
	server unknownPacketHandler
	// closed is set by the first call to close; later calls are no-ops.
	closed bool

	// deleteClosedSessionsAfter is how long a removed connection ID stays in
	// handlers as a nil marker before the entry is deleted.
	deleteClosedSessionsAfter time.Duration
	// logger receives debug output about packets that could not be handled.
	logger utils.Logger
}
  26. var _ packetHandlerManager = &packetHandlerMap{}
  27. func newPacketHandlerMap(conn net.PacketConn, connIDLen int, logger utils.Logger) packetHandlerManager {
  28. m := &packetHandlerMap{
  29. conn: conn,
  30. connIDLen: connIDLen,
  31. handlers: make(map[string]packetHandler),
  32. deleteClosedSessionsAfter: protocol.ClosedSessionDeleteTimeout,
  33. logger: logger,
  34. }
  35. go m.listen()
  36. return m
  37. }
  38. func (h *packetHandlerMap) Add(id protocol.ConnectionID, handler packetHandler) {
  39. h.mutex.Lock()
  40. h.handlers[string(id)] = handler
  41. h.mutex.Unlock()
  42. }
  43. func (h *packetHandlerMap) Remove(id protocol.ConnectionID) {
  44. h.removeByConnectionIDAsString(string(id))
  45. }
  46. func (h *packetHandlerMap) removeByConnectionIDAsString(id string) {
  47. h.mutex.Lock()
  48. h.handlers[id] = nil
  49. h.mutex.Unlock()
  50. time.AfterFunc(h.deleteClosedSessionsAfter, func() {
  51. h.mutex.Lock()
  52. delete(h.handlers, id)
  53. h.mutex.Unlock()
  54. })
  55. }
  56. func (h *packetHandlerMap) SetServer(s unknownPacketHandler) {
  57. h.mutex.Lock()
  58. h.server = s
  59. h.mutex.Unlock()
  60. }
  61. func (h *packetHandlerMap) CloseServer() {
  62. h.mutex.Lock()
  63. h.server = nil
  64. var wg sync.WaitGroup
  65. for id, handler := range h.handlers {
  66. if handler != nil && handler.GetPerspective() == protocol.PerspectiveServer {
  67. wg.Add(1)
  68. go func(id string, handler packetHandler) {
  69. // session.Close() blocks until the CONNECTION_CLOSE has been sent and the run-loop has stopped
  70. _ = handler.Close()
  71. h.removeByConnectionIDAsString(id)
  72. wg.Done()
  73. }(id, handler)
  74. }
  75. }
  76. h.mutex.Unlock()
  77. wg.Wait()
  78. }
  79. func (h *packetHandlerMap) close(e error) error {
  80. h.mutex.Lock()
  81. if h.closed {
  82. h.mutex.Unlock()
  83. return nil
  84. }
  85. h.closed = true
  86. var wg sync.WaitGroup
  87. for _, handler := range h.handlers {
  88. if handler != nil {
  89. wg.Add(1)
  90. go func(handler packetHandler) {
  91. handler.destroy(e)
  92. wg.Done()
  93. }(handler)
  94. }
  95. }
  96. if h.server != nil {
  97. h.server.closeWithError(e)
  98. }
  99. h.mutex.Unlock()
  100. wg.Wait()
  101. return nil
  102. }
  103. func (h *packetHandlerMap) listen() {
  104. for {
  105. data := *getPacketBuffer()
  106. data = data[:protocol.MaxReceivePacketSize]
  107. // The packet size should not exceed protocol.MaxReceivePacketSize bytes
  108. // If it does, we only read a truncated packet, which will then end up undecryptable
  109. n, addr, err := h.conn.ReadFrom(data)
  110. if err != nil {
  111. h.close(err)
  112. return
  113. }
  114. data = data[:n]
  115. if err := h.handlePacket(addr, data); err != nil {
  116. h.logger.Debugf("error handling packet from %s: %s", addr, err)
  117. }
  118. }
  119. }
  120. func (h *packetHandlerMap) handlePacket(addr net.Addr, data []byte) error {
  121. rcvTime := time.Now()
  122. r := bytes.NewReader(data)
  123. iHdr, err := wire.ParseInvariantHeader(r, h.connIDLen)
  124. // drop the packet if we can't parse the header
  125. if err != nil {
  126. return fmt.Errorf("error parsing invariant header: %s", err)
  127. }
  128. h.mutex.RLock()
  129. handler, ok := h.handlers[string(iHdr.DestConnectionID)]
  130. server := h.server
  131. h.mutex.RUnlock()
  132. var sentBy protocol.Perspective
  133. var version protocol.VersionNumber
  134. var handlePacket func(*receivedPacket)
  135. if ok && handler == nil {
  136. // Late packet for closed session
  137. return nil
  138. }
  139. if !ok {
  140. if server == nil { // no server set
  141. return fmt.Errorf("received a packet with an unexpected connection ID %s", iHdr.DestConnectionID)
  142. }
  143. handlePacket = server.handlePacket
  144. sentBy = protocol.PerspectiveClient
  145. version = iHdr.Version
  146. } else {
  147. sentBy = handler.GetPerspective().Opposite()
  148. version = handler.GetVersion()
  149. handlePacket = handler.handlePacket
  150. }
  151. hdr, err := iHdr.Parse(r, sentBy, version)
  152. if err != nil {
  153. return fmt.Errorf("error parsing header: %s", err)
  154. }
  155. hdr.Raw = data[:len(data)-r.Len()]
  156. packetData := data[len(data)-r.Len():]
  157. if hdr.IsLongHeader && hdr.Version.UsesLengthInHeader() {
  158. if protocol.ByteCount(len(packetData)) < hdr.PayloadLen {
  159. return fmt.Errorf("packet payload (%d bytes) is smaller than the expected payload length (%d bytes)", len(packetData), hdr.PayloadLen)
  160. }
  161. packetData = packetData[:int(hdr.PayloadLen)]
  162. // TODO(#1312): implement parsing of compound packets
  163. }
  164. handlePacket(&receivedPacket{
  165. remoteAddr: addr,
  166. header: hdr,
  167. data: packetData,
  168. rcvTime: rcvTime,
  169. })
  170. return nil
  171. }