Browse Source

also use the multiplexer for the server

Marten Seemann 9 months ago
parent
commit
ad5a3e2fa0

+ 13 - 0
client.go

@@ -544,9 +544,22 @@ func (c *client) Close() error {
 	return c.session.Close()
 }
 
+func (c *client) destroy(e error) {
+	c.mutex.Lock()
+	defer c.mutex.Unlock()
+	if c.session == nil {
+		return
+	}
+	c.session.destroy(e)
+}
+
 func (c *client) GetVersion() protocol.VersionNumber {
 	c.mutex.Lock()
 	v := c.version
 	c.mutex.Unlock()
 	return v
 }
+
+func (c *client) GetPerspective() protocol.Perspective {
+	return protocol.PerspectiveClient
+}

+ 5 - 0
internal/protocol/perspective.go

@@ -9,6 +9,11 @@ const (
 	PerspectiveClient Perspective = 2
 )
 
+// Opposite returns the perspective of the peer
+func (p Perspective) Opposite() Perspective {
+	return 3 - p
+}
+
 func (p Perspective) String() string {
 	switch p {
 	case PerspectiveServer:

+ 5 - 0
internal/protocol/perspective_test.go

@@ -11,4 +11,9 @@ var _ = Describe("Perspective", func() {
 		Expect(PerspectiveServer.String()).To(Equal("Server"))
 		Expect(Perspective(0).String()).To(Equal("invalid perspective"))
 	})
+
+	It("returns the opposite", func() {
+		Expect(PerspectiveClient.Opposite()).To(Equal(PerspectiveServer))
+		Expect(PerspectiveServer.Opposite()).To(Equal(PerspectiveClient))
+	})
 })

+ 16 - 21
mock_packet_handler_manager_test.go

@@ -44,29 +44,14 @@ func (mr *MockPacketHandlerManagerMockRecorder) Add(arg0, arg1 interface{}) *gom
 	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Add", reflect.TypeOf((*MockPacketHandlerManager)(nil).Add), arg0, arg1)
 }
 
-// Close mocks base method
-func (m *MockPacketHandlerManager) Close() error {
-	ret := m.ctrl.Call(m, "Close")
-	ret0, _ := ret[0].(error)
-	return ret0
+// CloseServer mocks base method
+func (m *MockPacketHandlerManager) CloseServer() {
+	m.ctrl.Call(m, "CloseServer")
 }
 
-// Close indicates an expected call of Close
-func (mr *MockPacketHandlerManagerMockRecorder) Close() *gomock.Call {
-	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockPacketHandlerManager)(nil).Close))
-}
-
-// Get mocks base method
-func (m *MockPacketHandlerManager) Get(arg0 protocol.ConnectionID) (packetHandler, bool) {
-	ret := m.ctrl.Call(m, "Get", arg0)
-	ret0, _ := ret[0].(packetHandler)
-	ret1, _ := ret[1].(bool)
-	return ret0, ret1
-}
-
-// Get indicates an expected call of Get
-func (mr *MockPacketHandlerManagerMockRecorder) Get(arg0 interface{}) *gomock.Call {
-	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockPacketHandlerManager)(nil).Get), arg0)
+// CloseServer indicates an expected call of CloseServer
+func (mr *MockPacketHandlerManagerMockRecorder) CloseServer() *gomock.Call {
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CloseServer", reflect.TypeOf((*MockPacketHandlerManager)(nil).CloseServer))
 }
 
 // Remove mocks base method
@@ -78,3 +63,13 @@ func (m *MockPacketHandlerManager) Remove(arg0 protocol.ConnectionID) {
 func (mr *MockPacketHandlerManagerMockRecorder) Remove(arg0 interface{}) *gomock.Call {
 	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Remove", reflect.TypeOf((*MockPacketHandlerManager)(nil).Remove), arg0)
 }
+
+// SetServer mocks base method
+func (m *MockPacketHandlerManager) SetServer(arg0 unknownPacketHandler) {
+	m.ctrl.Call(m, "SetServer", arg0)
+}
+
+// SetServer indicates an expected call of SetServer
+func (mr *MockPacketHandlerManagerMockRecorder) SetServer(arg0 interface{}) *gomock.Call {
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetServer", reflect.TypeOf((*MockPacketHandlerManager)(nil).SetServer), arg0)
+}

+ 91 - 0
mock_packet_handler_test.go

@@ -0,0 +1,91 @@
+// Code generated by MockGen. DO NOT EDIT.
+// Source: github.com/lucas-clemente/quic-go (interfaces: PacketHandler)
+
+// Package quic is a generated GoMock package.
+package quic
+
+import (
+	reflect "reflect"
+
+	gomock "github.com/golang/mock/gomock"
+	protocol "github.com/lucas-clemente/quic-go/internal/protocol"
+)
+
+// MockPacketHandler is a mock of PacketHandler interface
+type MockPacketHandler struct {
+	ctrl     *gomock.Controller
+	recorder *MockPacketHandlerMockRecorder
+}
+
+// MockPacketHandlerMockRecorder is the mock recorder for MockPacketHandler
+type MockPacketHandlerMockRecorder struct {
+	mock *MockPacketHandler
+}
+
+// NewMockPacketHandler creates a new mock instance
+func NewMockPacketHandler(ctrl *gomock.Controller) *MockPacketHandler {
+	mock := &MockPacketHandler{ctrl: ctrl}
+	mock.recorder = &MockPacketHandlerMockRecorder{mock}
+	return mock
+}
+
+// EXPECT returns an object that allows the caller to indicate expected use
+func (m *MockPacketHandler) EXPECT() *MockPacketHandlerMockRecorder {
+	return m.recorder
+}
+
+// Close mocks base method
+func (m *MockPacketHandler) Close() error {
+	ret := m.ctrl.Call(m, "Close")
+	ret0, _ := ret[0].(error)
+	return ret0
+}
+
+// Close indicates an expected call of Close
+func (mr *MockPacketHandlerMockRecorder) Close() *gomock.Call {
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockPacketHandler)(nil).Close))
+}
+
+// GetPerspective mocks base method
+func (m *MockPacketHandler) GetPerspective() protocol.Perspective {
+	ret := m.ctrl.Call(m, "GetPerspective")
+	ret0, _ := ret[0].(protocol.Perspective)
+	return ret0
+}
+
+// GetPerspective indicates an expected call of GetPerspective
+func (mr *MockPacketHandlerMockRecorder) GetPerspective() *gomock.Call {
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPerspective", reflect.TypeOf((*MockPacketHandler)(nil).GetPerspective))
+}
+
+// GetVersion mocks base method
+func (m *MockPacketHandler) GetVersion() protocol.VersionNumber {
+	ret := m.ctrl.Call(m, "GetVersion")
+	ret0, _ := ret[0].(protocol.VersionNumber)
+	return ret0
+}
+
+// GetVersion indicates an expected call of GetVersion
+func (mr *MockPacketHandlerMockRecorder) GetVersion() *gomock.Call {
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetVersion", reflect.TypeOf((*MockPacketHandler)(nil).GetVersion))
+}
+
+// destroy mocks base method
+func (m *MockPacketHandler) destroy(arg0 error) {
+	m.ctrl.Call(m, "destroy", arg0)
+}
+
+// destroy indicates an expected call of destroy
+func (mr *MockPacketHandlerMockRecorder) destroy(arg0 interface{}) *gomock.Call {
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "destroy", reflect.TypeOf((*MockPacketHandler)(nil).destroy), arg0)
+}
+
+// handlePacket mocks base method
+func (m *MockPacketHandler) handlePacket(arg0 *receivedPacket) {
+	m.ctrl.Call(m, "handlePacket", arg0)
+}
+
+// handlePacket indicates an expected call of handlePacket
+func (mr *MockPacketHandlerMockRecorder) handlePacket(arg0 interface{}) *gomock.Call {
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handlePacket", reflect.TypeOf((*MockPacketHandler)(nil).handlePacket), arg0)
+}

+ 56 - 0
mock_unknown_packet_handler_test.go

@@ -0,0 +1,56 @@
+// Code generated by MockGen. DO NOT EDIT.
+// Source: github.com/lucas-clemente/quic-go (interfaces: UnknownPacketHandler)
+
+// Package quic is a generated GoMock package.
+package quic
+
+import (
+	reflect "reflect"
+
+	gomock "github.com/golang/mock/gomock"
+)
+
+// MockUnknownPacketHandler is a mock of UnknownPacketHandler interface
+type MockUnknownPacketHandler struct {
+	ctrl     *gomock.Controller
+	recorder *MockUnknownPacketHandlerMockRecorder
+}
+
+// MockUnknownPacketHandlerMockRecorder is the mock recorder for MockUnknownPacketHandler
+type MockUnknownPacketHandlerMockRecorder struct {
+	mock *MockUnknownPacketHandler
+}
+
+// NewMockUnknownPacketHandler creates a new mock instance
+func NewMockUnknownPacketHandler(ctrl *gomock.Controller) *MockUnknownPacketHandler {
+	mock := &MockUnknownPacketHandler{ctrl: ctrl}
+	mock.recorder = &MockUnknownPacketHandlerMockRecorder{mock}
+	return mock
+}
+
+// EXPECT returns an object that allows the caller to indicate expected use
+func (m *MockUnknownPacketHandler) EXPECT() *MockUnknownPacketHandlerMockRecorder {
+	return m.recorder
+}
+
+// closeWithError mocks base method
+func (m *MockUnknownPacketHandler) closeWithError(arg0 error) error {
+	ret := m.ctrl.Call(m, "closeWithError", arg0)
+	ret0, _ := ret[0].(error)
+	return ret0
+}
+
+// closeWithError indicates an expected call of closeWithError
+func (mr *MockUnknownPacketHandlerMockRecorder) closeWithError(arg0 interface{}) *gomock.Call {
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "closeWithError", reflect.TypeOf((*MockUnknownPacketHandler)(nil).closeWithError), arg0)
+}
+
+// handlePacket mocks base method
+func (m *MockUnknownPacketHandler) handlePacket(arg0 *receivedPacket) {
+	m.ctrl.Call(m, "handlePacket", arg0)
+}
+
+// handlePacket indicates an expected call of handlePacket
+func (mr *MockUnknownPacketHandlerMockRecorder) handlePacket(arg0 interface{}) *gomock.Call {
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handlePacket", reflect.TypeOf((*MockUnknownPacketHandler)(nil).handlePacket), arg0)
+}

+ 2 - 0
mockgen.go

@@ -13,6 +13,8 @@ package quic
 //go:generate sh -c "./mockgen_private.sh quic mock_gquic_aead_test.go github.com/lucas-clemente/quic-go gQUICAEAD GQUICAEAD"
 //go:generate sh -c "./mockgen_private.sh quic mock_session_runner_test.go github.com/lucas-clemente/quic-go sessionRunner SessionRunner"
 //go:generate sh -c "./mockgen_private.sh quic mock_quic_session_test.go github.com/lucas-clemente/quic-go quicSession QuicSession"
+//go:generate sh -c "./mockgen_private.sh quic mock_packet_handler_test.go github.com/lucas-clemente/quic-go packetHandler PacketHandler"
+//go:generate sh -c "./mockgen_private.sh quic mock_unknown_packet_handler_test.go github.com/lucas-clemente/quic-go unknownPacketHandler UnknownPacketHandler"
 //go:generate sh -c "./mockgen_private.sh quic mock_packet_handler_manager_test.go github.com/lucas-clemente/quic-go packetHandlerManager PacketHandlerManager"
 //go:generate sh -c "./mockgen_private.sh quic mock_multiplexer_test.go github.com/lucas-clemente/quic-go multiplexer Multiplexer"
 //go:generate sh -c "find . -type f -name 'mock_*_test.go' | xargs sed -i '' 's/quic_go.//g'"

+ 2 - 2
multiplexer.go

@@ -28,7 +28,7 @@ type connMultiplexer struct {
 	mutex sync.Mutex
 
 	conns                   map[net.PacketConn]connManager
-	newPacketHandlerManager func(net.PacketConn, int, utils.Logger, bool) packetHandlerManager // so it can be replaced in the tests
+	newPacketHandlerManager func(net.PacketConn, int, utils.Logger) packetHandlerManager // so it can be replaced in the tests
 
 	logger utils.Logger
 }
@@ -52,7 +52,7 @@ func (m *connMultiplexer) AddConn(c net.PacketConn, connIDLen int) (packetHandle
 
 	p, ok := m.conns[c]
 	if !ok {
-		manager := m.newPacketHandlerManager(c, connIDLen, m.logger, true)
+		manager := m.newPacketHandlerManager(c, connIDLen, m.logger)
 		p = connManager{connIDLen: connIDLen, manager: manager}
 		m.conns[c] = p
 	}

+ 66 - 28
packet_handler_map.go

@@ -4,7 +4,6 @@ import (
 	"bytes"
 	"fmt"
 	"net"
-	"strings"
 	"sync"
 	"time"
 
@@ -24,6 +23,7 @@ type packetHandlerMap struct {
 	connIDLen int
 
 	handlers map[string] /* string(ConnectionID)*/ packetHandler
+	server   unknownPacketHandler
 	closed   bool
 
 	deleteClosedSessionsAfter time.Duration
@@ -33,8 +33,7 @@ type packetHandlerMap struct {
 
 var _ packetHandlerManager = &packetHandlerMap{}
 
-// TODO(#561): remove the listen flag
-func newPacketHandlerMap(conn net.PacketConn, connIDLen int, logger utils.Logger, listen bool) packetHandlerManager {
+func newPacketHandlerMap(conn net.PacketConn, connIDLen int, logger utils.Logger) packetHandlerManager {
 	m := &packetHandlerMap{
 		conn:                      conn,
 		connIDLen:                 connIDLen,
@@ -42,19 +41,10 @@ func newPacketHandlerMap(conn net.PacketConn, connIDLen int, logger utils.Logger
 		deleteClosedSessionsAfter: protocol.ClosedSessionDeleteTimeout,
 		logger: logger,
 	}
-	if listen {
-		go m.listen()
-	}
+	go m.listen()
 	return m
 }
 
-func (h *packetHandlerMap) Get(id protocol.ConnectionID) (packetHandler, bool) {
-	h.mutex.RLock()
-	sess, ok := h.handlers[string(id)]
-	h.mutex.RUnlock()
-	return sess, ok
-}
-
 func (h *packetHandlerMap) Add(id protocol.ConnectionID, handler packetHandler) {
 	h.mutex.Lock()
 	h.handlers[string(id)] = handler
@@ -62,18 +52,47 @@ func (h *packetHandlerMap) Add(id protocol.ConnectionID, handler packetHandler)
 }
 
 func (h *packetHandlerMap) Remove(id protocol.ConnectionID) {
+	h.removeByConnectionIDAsString(string(id))
+}
+
+func (h *packetHandlerMap) removeByConnectionIDAsString(id string) {
 	h.mutex.Lock()
-	h.handlers[string(id)] = nil
+	h.handlers[id] = nil
 	h.mutex.Unlock()
 
 	time.AfterFunc(h.deleteClosedSessionsAfter, func() {
 		h.mutex.Lock()
-		delete(h.handlers, string(id))
+		delete(h.handlers, id)
 		h.mutex.Unlock()
 	})
 }
 
-func (h *packetHandlerMap) Close() error {
+func (h *packetHandlerMap) SetServer(s unknownPacketHandler) {
+	h.mutex.Lock()
+	h.server = s
+	h.mutex.Unlock()
+}
+
+func (h *packetHandlerMap) CloseServer() {
+	h.mutex.Lock()
+	h.server = nil
+	var wg sync.WaitGroup
+	for id, handler := range h.handlers {
+		if handler != nil && handler.GetPerspective() == protocol.PerspectiveServer {
+			wg.Add(1)
+			go func(id string, handler packetHandler) {
+				// session.Close() blocks until the CONNECTION_CLOSE has been sent and the run-loop has stopped
+				_ = handler.Close()
+				h.removeByConnectionIDAsString(id)
+				wg.Done()
+			}(id, handler)
+		}
+	}
+	h.mutex.Unlock()
+	wg.Wait()
+}
+
+func (h *packetHandlerMap) close(e error) error {
 	h.mutex.Lock()
 	if h.closed {
 		h.mutex.Unlock()
@@ -86,12 +105,15 @@ func (h *packetHandlerMap) Close() error {
 		if handler != nil {
 			wg.Add(1)
 			go func(handler packetHandler) {
-				// session.Close() blocks until the CONNECTION_CLOSE has been sent and the run-loop has stopped
-				_ = handler.Close()
+				handler.destroy(e)
 				wg.Done()
 			}(handler)
 		}
 	}
+
+	if h.server != nil {
+		h.server.closeWithError(e)
+	}
 	h.mutex.Unlock()
 	wg.Wait()
 	return nil
@@ -105,9 +127,7 @@ func (h *packetHandlerMap) listen() {
 		// If it does, we only read a truncated packet, which will then end up undecryptable
 		n, addr, err := h.conn.ReadFrom(data)
 		if err != nil {
-			if !strings.HasSuffix(err.Error(), "use of closed network connection") {
-				h.Close()
-			}
+			h.close(err)
 			return
 		}
 		data = data[:n]
@@ -127,15 +147,33 @@ func (h *packetHandlerMap) handlePacket(addr net.Addr, data []byte) error {
 	if err != nil {
 		return fmt.Errorf("error parsing invariant header: %s", err)
 	}
-	handler, ok := h.Get(iHdr.DestConnectionID)
-	if !ok {
-		return fmt.Errorf("received a packet with an unexpected connection ID %s", iHdr.DestConnectionID)
-	}
-	if handler == nil {
+
+	h.mutex.RLock()
+	handler, ok := h.handlers[string(iHdr.DestConnectionID)]
+	server := h.server
+	h.mutex.RUnlock()
+
+	var sentBy protocol.Perspective
+	var version protocol.VersionNumber
+	var handlePacket func(*receivedPacket)
+	if ok && handler == nil {
 		// Late packet for closed session
 		return nil
 	}
-	hdr, err := iHdr.Parse(r, protocol.PerspectiveServer, handler.GetVersion())
+	if !ok {
+		if server == nil { // no server set
+			return fmt.Errorf("received a packet with an unexpected connection ID %s", iHdr.DestConnectionID)
+		}
+		handlePacket = server.handlePacket
+		sentBy = protocol.PerspectiveClient
+		version = iHdr.Version
+	} else {
+		sentBy = handler.GetPerspective().Opposite()
+		version = handler.GetVersion()
+		handlePacket = handler.handlePacket
+	}
+
+	hdr, err := iHdr.Parse(r, sentBy, version)
 	if err != nil {
 		return fmt.Errorf("error parsing header: %s", err)
 	}
@@ -150,7 +188,7 @@ func (h *packetHandlerMap) handlePacket(addr net.Addr, data []byte) error {
 		// TODO(#1312): implement parsing of compound packets
 	}
 
-	handler.handlePacket(&receivedPacket{
+	handlePacket(&receivedPacket{
 		remoteAddr: addr,
 		header:     hdr,
 		data:       packetData,

+ 76 - 54
packet_handler_map_test.go

@@ -2,6 +2,7 @@ package quic
 
 import (
 	"bytes"
+	"errors"
 	"time"
 
 	"github.com/golang/mock/gomock"
@@ -18,66 +19,38 @@ var _ = Describe("Packet Handler Map", func() {
 		conn    *mockPacketConn
 	)
 
+	getPacket := func(connID protocol.ConnectionID) []byte {
+		buf := &bytes.Buffer{}
+		err := (&wire.Header{
+			DestConnectionID: connID,
+			PacketNumberLen:  protocol.PacketNumberLen1,
+		}).Write(buf, protocol.PerspectiveServer, protocol.VersionWhatever)
+		Expect(err).ToNot(HaveOccurred())
+		return buf.Bytes()
+	}
+
 	BeforeEach(func() {
 		conn = newMockPacketConn()
-		handler = newPacketHandlerMap(conn, 5, utils.DefaultLogger, true).(*packetHandlerMap)
-	})
-
-	It("adds and gets", func() {
-		connID := protocol.ConnectionID{1, 2, 3, 4, 5}
-		sess := &mockSession{}
-		handler.Add(connID, sess)
-		session, ok := handler.Get(connID)
-		Expect(ok).To(BeTrue())
-		Expect(session).To(Equal(sess))
-	})
-
-	It("deletes", func() {
-		connID := protocol.ConnectionID{1, 2, 3, 4, 5}
-		handler.Add(connID, &mockSession{})
-		handler.Remove(connID)
-		session, ok := handler.Get(connID)
-		Expect(ok).To(BeTrue())
-		Expect(session).To(BeNil())
-	})
-
-	It("deletes nil session entries after a wait time", func() {
-		handler.deleteClosedSessionsAfter = 25 * time.Millisecond
-		connID := protocol.ConnectionID{1, 2, 3, 4, 5}
-		handler.Add(connID, &mockSession{})
-		handler.Remove(connID)
-		Eventually(func() bool {
-			_, ok := handler.Get(connID)
-			return ok
-		}).Should(BeFalse())
+		handler = newPacketHandlerMap(conn, 5, utils.DefaultLogger).(*packetHandlerMap)
 	})
 
 	It("closes", func() {
-		sess1 := NewMockQuicSession(mockCtrl)
-		sess1.EXPECT().Close()
-		sess2 := NewMockQuicSession(mockCtrl)
-		sess2.EXPECT().Close()
+		testErr := errors.New("test error	")
+		sess1 := NewMockPacketHandler(mockCtrl)
+		sess1.EXPECT().destroy(testErr)
+		sess2 := NewMockPacketHandler(mockCtrl)
+		sess2.EXPECT().destroy(testErr)
 		handler.Add(protocol.ConnectionID{1, 1, 1, 1}, sess1)
 		handler.Add(protocol.ConnectionID{2, 2, 2, 2}, sess2)
-		handler.Close()
+		handler.close(testErr)
 	})
 
 	Context("handling packets", func() {
-		getPacket := func(connID protocol.ConnectionID) []byte {
-			buf := &bytes.Buffer{}
-			err := (&wire.Header{
-				DestConnectionID: connID,
-				PacketNumberLen:  protocol.PacketNumberLen1,
-			}).Write(buf, protocol.PerspectiveServer, protocol.VersionWhatever)
-			Expect(err).ToNot(HaveOccurred())
-			return buf.Bytes()
-		}
-
 		It("handles packets for different packet handlers on the same packet conn", func() {
 			connID1 := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
 			connID2 := protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1}
-			packetHandler1 := NewMockQuicSession(mockCtrl)
-			packetHandler2 := NewMockQuicSession(mockCtrl)
+			packetHandler1 := NewMockPacketHandler(mockCtrl)
+			packetHandler2 := NewMockPacketHandler(mockCtrl)
 			handledPacket1 := make(chan struct{})
 			handledPacket2 := make(chan struct{})
 			packetHandler1.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) {
@@ -85,11 +58,13 @@ var _ = Describe("Packet Handler Map", func() {
 				close(handledPacket1)
 			})
 			packetHandler1.EXPECT().GetVersion()
+			packetHandler1.EXPECT().GetPerspective().Return(protocol.PerspectiveClient)
 			packetHandler2.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) {
 				Expect(p.header.DestConnectionID).To(Equal(connID2))
 				close(handledPacket2)
 			})
 			packetHandler2.EXPECT().GetVersion()
+			packetHandler2.EXPECT().GetPerspective().Return(protocol.PerspectiveClient)
 			handler.Add(connID1, packetHandler1)
 			handler.Add(connID2, packetHandler2)
 
@@ -99,8 +74,8 @@ var _ = Describe("Packet Handler Map", func() {
 			Eventually(handledPacket2).Should(BeClosed())
 
 			// makes the listen go routine return
-			packetHandler1.EXPECT().Close().AnyTimes()
-			packetHandler2.EXPECT().Close().AnyTimes()
+			packetHandler1.EXPECT().destroy(gomock.Any()).AnyTimes()
+			packetHandler2.EXPECT().destroy(gomock.Any()).AnyTimes()
 			close(conn.dataToRead)
 		})
 
@@ -110,10 +85,20 @@ var _ = Describe("Packet Handler Map", func() {
 			Expect(err.Error()).To(ContainSubstring("error parsing invariant header:"))
 		})
 
+		It("deletes nil session entries after a wait time", func() {
+			handler.deleteClosedSessionsAfter = 10 * time.Millisecond
+			connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
+			handler.Add(connID, NewMockPacketHandler(mockCtrl))
+			handler.Remove(connID)
+			Eventually(func() error {
+				return handler.handlePacket(nil, getPacket(connID))
+			}).Should(MatchError("received a packet with an unexpected connection ID 0x0102030405060708"))
+		})
+
 		It("ignores packets arriving late for closed sessions", func() {
 			handler.deleteClosedSessionsAfter = time.Hour
 			connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
-			handler.Add(connID, NewMockQuicSession(mockCtrl))
+			handler.Add(connID, NewMockPacketHandler(mockCtrl))
 			handler.Remove(connID)
 			err := handler.handlePacket(nil, getPacket(connID))
 			Expect(err).ToNot(HaveOccurred())
@@ -127,8 +112,9 @@ var _ = Describe("Packet Handler Map", func() {
 
 		It("errors on packets that are smaller than the Payload Length in the packet header", func() {
 			connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
-			packetHandler := NewMockQuicSession(mockCtrl)
+			packetHandler := NewMockPacketHandler(mockCtrl)
 			packetHandler.EXPECT().GetVersion().Return(versionIETFFrames)
+			packetHandler.EXPECT().GetPerspective().Return(protocol.PerspectiveClient)
 			handler.Add(connID, packetHandler)
 			hdr := &wire.Header{
 				IsLongHeader:     true,
@@ -148,8 +134,9 @@ var _ = Describe("Packet Handler Map", func() {
 
 		It("cuts packets at the Payload Length", func() {
 			connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
-			packetHandler := NewMockQuicSession(mockCtrl)
+			packetHandler := NewMockPacketHandler(mockCtrl)
 			packetHandler.EXPECT().GetVersion().Return(versionIETFFrames)
+			packetHandler.EXPECT().GetPerspective().Return(protocol.PerspectiveClient)
 			handler.Add(connID, packetHandler)
 			packetHandler.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) {
 				Expect(p.data).To(HaveLen(456))
@@ -172,8 +159,9 @@ var _ = Describe("Packet Handler Map", func() {
 
 		It("closes the packet handlers when reading from the conn fails", func() {
 			done := make(chan struct{})
-			packetHandler := NewMockQuicSession(mockCtrl)
-			packetHandler.EXPECT().Close().Do(func() {
+			packetHandler := NewMockPacketHandler(mockCtrl)
+			packetHandler.EXPECT().destroy(gomock.Any()).Do(func(e error) {
+				Expect(e).To(HaveOccurred())
 				close(done)
 			})
 			handler.Add(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, packetHandler)
@@ -181,4 +169,38 @@ var _ = Describe("Packet Handler Map", func() {
 			Eventually(done).Should(BeClosed())
 		})
 	})
+
+	Context("running a server", func() {
+		It("adds a server", func() {
+			connID := protocol.ConnectionID{0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88}
+			p := getPacket(connID)
+			server := NewMockUnknownPacketHandler(mockCtrl)
+			server.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) {
+				Expect(p.header.DestConnectionID).To(Equal(connID))
+			})
+			handler.SetServer(server)
+			Expect(handler.handlePacket(nil, p)).To(Succeed())
+		})
+
+		It("closes all server sessions", func() {
+			clientSess := NewMockPacketHandler(mockCtrl)
+			clientSess.EXPECT().GetPerspective().Return(protocol.PerspectiveClient)
+			serverSess := NewMockPacketHandler(mockCtrl)
+			serverSess.EXPECT().GetPerspective().Return(protocol.PerspectiveServer)
+			serverSess.EXPECT().Close()
+
+			handler.Add(protocol.ConnectionID{1, 1, 1, 1}, clientSess)
+			handler.Add(protocol.ConnectionID{2, 2, 2, 2}, serverSess)
+			handler.CloseServer()
+		})
+
+		It("stops handling packets with unknown connection IDs after the server is closed", func() {
+			connID := protocol.ConnectionID{0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88}
+			p := getPacket(connID)
+			server := NewMockUnknownPacketHandler(mockCtrl)
+			handler.SetServer(server)
+			handler.CloseServer()
+			Expect(handler.handlePacket(nil, p)).To(MatchError("received a packet with an unexpected connection ID 0x1122334455667788"))
+		})
+	})
 })

+ 67 - 167
server.go

@@ -1,7 +1,6 @@
 package quic
 
 import (
-	"bytes"
 	"crypto/tls"
 	"errors"
 	"fmt"
@@ -14,21 +13,27 @@ import (
 	"github.com/lucas-clemente/quic-go/internal/protocol"
 	"github.com/lucas-clemente/quic-go/internal/utils"
 	"github.com/lucas-clemente/quic-go/internal/wire"
-	"github.com/lucas-clemente/quic-go/qerr"
 )
 
 // packetHandler handles packets
 type packetHandler interface {
 	handlePacket(*receivedPacket)
-	GetVersion() protocol.VersionNumber
 	io.Closer
+	destroy(error)
+	GetVersion() protocol.VersionNumber
+	GetPerspective() protocol.Perspective
+}
+
+type unknownPacketHandler interface {
+	handlePacket(*receivedPacket)
+	closeWithError(error) error
 }
 
 type packetHandlerManager interface {
 	Add(protocol.ConnectionID, packetHandler)
-	Get(protocol.ConnectionID) (packetHandler, bool)
+	SetServer(unknownPacketHandler)
 	Remove(protocol.ConnectionID)
-	io.Closer
+	CloseServer()
 }
 
 type quicSession interface {
@@ -84,6 +89,7 @@ type server struct {
 }
 
 var _ Listener = &server{}
+var _ unknownPacketHandler = &server{}
 
 // ListenAddr creates a QUIC server listening on a given address.
 // The tls.Config must not be nil, the quic.Config may be nil.
@@ -125,7 +131,10 @@ func Listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (Listener,
 		}
 	}
 
-	logger := utils.DefaultLogger.WithPrefix("server")
+	sessionHandler, err := getMultiplexer().AddConn(conn, config.ConnectionIDLength)
+	if err != nil {
+		return nil, err
+	}
 	s := &server{
 		conn:           conn,
 		tlsConf:        tlsConf,
@@ -133,11 +142,11 @@ func Listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (Listener,
 		certChain:      certChain,
 		scfg:           scfg,
 		newSession:     newSession,
-		sessionHandler: newPacketHandlerMap(conn, config.ConnectionIDLength, logger, false),
+		sessionHandler: sessionHandler,
 		sessionQueue:   make(chan Session, 5),
 		errorChan:      make(chan struct{}),
 		supportsTLS:    supportsTLS,
-		logger:         logger,
+		logger:         utils.DefaultLogger.WithPrefix("server"),
 	}
 	s.setup()
 	if supportsTLS {
@@ -145,7 +154,7 @@ func Listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (Listener,
 			return nil, err
 		}
 	}
-	go s.serve()
+	sessionHandler.SetServer(s)
 	s.logger.Debugf("Listening for %s connections on %s", conn.LocalAddr().Network(), conn.LocalAddr().String())
 	return s, nil
 }
@@ -176,7 +185,8 @@ func (s *server) setupTLS() error {
 			case tlsSession := <-sessionChan:
 				// The connection ID is a randomly chosen 8 byte value.
 				// It is safe to assume that it doesn't collide with other randomly chosen values.
-				s.sessionHandler.Add(tlsSession.connID, tlsSession.sess)
+				serverSession := newServerSession(tlsSession.sess, s.config, s.logger)
+				s.sessionHandler.Add(tlsSession.connID, serverSession)
 			}
 		}
 	}()
@@ -263,27 +273,6 @@ func populateServerConfig(config *Config) *Config {
 	}
 }
 
-// serve listens on an existing PacketConn
-func (s *server) serve() {
-	for {
-		data := *getPacketBuffer()
-		data = data[:protocol.MaxReceivePacketSize]
-		// The packet size should not exceed protocol.MaxReceivePacketSize bytes
-		// If it does, we only read a truncated packet, which will then end up undecryptable
-		n, remoteAddr, err := s.conn.ReadFrom(data)
-		if err != nil {
-			s.serverError = err
-			close(s.errorChan)
-			_ = s.Close()
-			return
-		}
-		data = data[:n]
-		if err := s.handlePacket(remoteAddr, data); err != nil {
-			s.logger.Errorf("error handling packet: %s", err.Error())
-		}
-	}
-}
-
 // Accept returns newly openend sessions
 func (s *server) Accept() (Session, error) {
 	var sess Session
@@ -297,10 +286,13 @@ func (s *server) Accept() (Session, error) {
 
 // Close the server
 func (s *server) Close() error {
-	s.sessionHandler.Close()
-	err := s.conn.Close()
-	<-s.errorChan // wait for serve() to return
-	return err
+	s.sessionHandler.CloseServer()
+	// TODO: close the conn if this server was started with ListenAddr() (but not with Listen(net.PacketConn))
+	if s.serverError == nil {
+		s.serverError = errors.New("server closed")
+	}
+	close(s.errorChan)
+	return nil
 }
 
 // Addr returns the server's network address
@@ -308,157 +300,65 @@ func (s *server) Addr() net.Addr {
 	return s.conn.LocalAddr()
 }
 
-func (s *server) handlePacket(remoteAddr net.Addr, packet []byte) error {
-	rcvTime := time.Now()
-
-	r := bytes.NewReader(packet)
-	iHdr, err := wire.ParseInvariantHeader(r, s.config.ConnectionIDLength)
-	if err != nil {
-		return qerr.Error(qerr.InvalidPacketHeader, err.Error())
-	}
-	session, sessionKnown := s.sessionHandler.Get(iHdr.DestConnectionID)
-	if sessionKnown && session == nil {
-		// Late packet for closed session
-		return nil
-	}
-	version := protocol.VersionUnknown
-	if sessionKnown {
-		version = session.GetVersion()
-	}
-	hdr, err := iHdr.Parse(r, protocol.PerspectiveClient, version)
-	if err != nil {
-		return qerr.Error(qerr.InvalidPacketHeader, err.Error())
-	}
-	hdr.Raw = packet[:len(packet)-r.Len()]
-	packetData := packet[len(packet)-r.Len():]
-
-	if hdr.IsPublicHeader {
-		return s.handleGQUICPacket(session, hdr, packetData, remoteAddr, rcvTime)
-	}
-	return s.handleIETFQUICPacket(session, hdr, packetData, remoteAddr, rcvTime)
+func (s *server) closeWithError(e error) error {
+	s.serverError = e
+	return s.Close()
 }
 
-func (s *server) handleIETFQUICPacket(
-	session packetHandler,
-	hdr *wire.Header,
-	packetData []byte,
-	remoteAddr net.Addr,
-	rcvTime time.Time,
-) error {
-	if hdr.IsLongHeader {
-		if !s.supportsTLS {
-			return errors.New("Received an IETF QUIC Long Header")
-		}
-		if protocol.ByteCount(len(packetData)) < hdr.PayloadLen {
-			return fmt.Errorf("packet payload (%d bytes) is smaller than the expected payload length (%d bytes)", len(packetData), hdr.PayloadLen)
-		}
-		packetData = packetData[:int(hdr.PayloadLen)]
-		// TODO(#1312): implement parsing of compound packets
-
-		switch hdr.Type {
-		case protocol.PacketTypeInitial:
-			go s.serverTLS.HandleInitial(remoteAddr, hdr, packetData)
-			return nil
-		case protocol.PacketTypeHandshake:
-			// nothing to do here. Packet will be passed to the session.
-		default:
-			// Note that this also drops 0-RTT packets.
-			return fmt.Errorf("Received unsupported packet type: %s", hdr.Type)
-		}
-	}
-
-	if session == nil {
-		s.logger.Debugf("Received %s packet for unknown connection %s.", hdr.Type, hdr.DestConnectionID)
-		return nil
+func (s *server) handlePacket(p *receivedPacket) {
+	if err := s.handlePacketImpl(p); err != nil {
+		s.logger.Debugf("error handling packet from %s: %s", p.remoteAddr, err)
 	}
-
-	session.handlePacket(&receivedPacket{
-		remoteAddr: remoteAddr,
-		header:     hdr,
-		data:       packetData,
-		rcvTime:    rcvTime,
-	})
-	return nil
 }
 
-func (s *server) handleGQUICPacket(
-	session packetHandler,
-	hdr *wire.Header,
-	packetData []byte,
-	remoteAddr net.Addr,
-	rcvTime time.Time,
-) error {
-	// ignore all Public Reset packets
-	if hdr.ResetFlag {
-		s.logger.Infof("Received unexpected Public Reset for connection %s.", hdr.DestConnectionID)
+func (s *server) handlePacketImpl(p *receivedPacket) error {
+	hdr := p.header
+	version := hdr.Version
+
+	if hdr.Type == protocol.PacketTypeInitial {
+		go s.serverTLS.HandleInitial(p.remoteAddr, hdr, p.data)
 		return nil
 	}
 
-	sessionKnown := session != nil
-
-	// If we don't have a session for this connection, and this packet cannot open a new connection, send a Public Reset
-	// This should only happen after a server restart, when we still receive packets for connections that we lost the state for.
-	if !sessionKnown && !hdr.VersionFlag {
-		_, err := s.conn.WriteTo(wire.WritePublicReset(hdr.DestConnectionID, 0, 0), remoteAddr)
+	if !hdr.VersionFlag {
+		_, err := s.conn.WriteTo(wire.WritePublicReset(hdr.DestConnectionID, 0, 0), p.remoteAddr)
 		return err
 	}
 
-	// a session is only created once the client sent a supported version
-	// if we receive a packet for a connection that already has session, it's probably an old packet that was sent by the client before the version was negotiated
-	// it is safe to drop it
-	if sessionKnown && hdr.VersionFlag && !protocol.IsSupportedVersion(s.config.Versions, hdr.Version) {
-		return nil
+	// This is (potentially) a Client Hello.
+	// Make sure it has the minimum required size before spending any more ressources on it.
+	if len(p.data) < protocol.MinClientHelloSize {
+		return errors.New("dropping small packet for unknown connection")
 	}
 
 	// send a Version Negotiation Packet if the client is speaking a different protocol version
 	// since the client send a Public Header (only gQUIC has a Version Flag), we need to send a gQUIC Version Negotiation Packet
-	if hdr.VersionFlag && !protocol.IsSupportedVersion(s.config.Versions, hdr.Version) {
-		// drop packets that are too small to be valid first packets
-		if len(packetData) < protocol.MinClientHelloSize {
-			return errors.New("dropping small packet with unknown version")
-		}
-		s.logger.Infof("Client offered version %s, sending Version Negotiation Packet", hdr.Version)
-		_, err := s.conn.WriteTo(wire.ComposeGQUICVersionNegotiation(hdr.DestConnectionID, s.config.Versions), remoteAddr)
+	if hdr.VersionFlag && !protocol.IsSupportedVersion(s.config.Versions, version) {
+		s.logger.Infof("Client offered version %s, sending Version Negotiation Packet", version)
+		_, err := s.conn.WriteTo(wire.ComposeGQUICVersionNegotiation(hdr.DestConnectionID, s.config.Versions), p.remoteAddr)
 		return err
 	}
 
-	if !sessionKnown {
-		// This is (potentially) a Client Hello.
-		// Make sure it has the minimum required size before spending any more ressources on it.
-		if len(packetData) < protocol.MinClientHelloSize {
-			return errors.New("dropping small packet for unknown connection")
-		}
-
-		version := hdr.Version
-		if !protocol.IsSupportedVersion(s.config.Versions, version) {
-			return errors.New("Server BUG: negotiated version not supported")
-		}
-
-		s.logger.Infof("Serving new connection: %s, version %s from %v", hdr.DestConnectionID, version, remoteAddr)
-		sess, err := s.newSession(
-			&conn{pconn: s.conn, currentAddr: remoteAddr},
-			s.sessionRunner,
-			version,
-			hdr.DestConnectionID,
-			s.scfg,
-			s.tlsConf,
-			s.config,
-			s.logger,
-		)
-		if err != nil {
-			return err
-		}
-		s.sessionHandler.Add(hdr.DestConnectionID, sess)
-
-		go sess.run()
-		session = sess
+	if !protocol.IsSupportedVersion(s.config.Versions, version) {
+		return errors.New("Server BUG: negotiated version not supported")
 	}
 
-	session.handlePacket(&receivedPacket{
-		remoteAddr: remoteAddr,
-		header:     hdr,
-		data:       packetData,
-		rcvTime:    rcvTime,
-	})
+	s.logger.Infof("Serving new connection: %s, version %s from %v", hdr.DestConnectionID, version, p.remoteAddr)
+	sess, err := s.newSession(
+		&conn{pconn: s.conn, currentAddr: p.remoteAddr},
+		s.sessionRunner,
+		version,
+		hdr.DestConnectionID,
+		s.scfg,
+		s.tlsConf,
+		s.config,
+		s.logger,
+	)
+	if err != nil {
+		return err
+	}
+	s.sessionHandler.Add(hdr.DestConnectionID, newServerSession(sess, s.config, s.logger))
+	go sess.run()
+	sess.handlePacket(p)
 	return nil
 }

+ 63 - 0
server_session.go

@@ -0,0 +1,63 @@
+package quic
+
+import (
+	"fmt"
+
+	"github.com/lucas-clemente/quic-go/internal/protocol"
+	"github.com/lucas-clemente/quic-go/internal/utils"
+)
+
+type serverSession struct {
+	quicSession
+
+	config *Config
+
+	logger utils.Logger
+}
+
+var _ packetHandler = &serverSession{}
+
+func newServerSession(sess quicSession, config *Config, logger utils.Logger) packetHandler {
+	return &serverSession{
+		quicSession: sess,
+		config:      config,
+		logger:      logger,
+	}
+}
+
+func (s *serverSession) handlePacket(p *receivedPacket) {
+	if err := s.handlePacketImpl(p); err != nil {
+		s.logger.Debugf("error handling packet from %s: %s", p.remoteAddr, err)
+	}
+}
+
+func (s *serverSession) handlePacketImpl(p *receivedPacket) error {
+	hdr := p.header
+	// ignore all Public Reset packets
+	if hdr.ResetFlag {
+		return fmt.Errorf("Received unexpected Public Reset for connection %s", hdr.DestConnectionID)
+	}
+
+	// Probably an old packet that was sent by the client before the version was negotiated.
+	// It is safe to drop it.
+	if (hdr.VersionFlag || hdr.IsLongHeader) && hdr.Version != s.quicSession.GetVersion() {
+		return nil
+	}
+
+	if hdr.IsLongHeader {
+		switch hdr.Type {
+		case protocol.PacketTypeHandshake:
+			// nothing to do here. Packet will be passed to the session.
+		default:
+			// Note that this also drops 0-RTT packets.
+			return fmt.Errorf("Received unsupported packet type: %s", hdr.Type)
+		}
+	}
+
+	s.quicSession.handlePacket(p)
+	return nil
+}
+
+func (s *serverSession) GetPerspective() protocol.Perspective {
+	return protocol.PerspectiveServer
+}

+ 101 - 0
server_session_test.go

@@ -0,0 +1,101 @@
+package quic
+
+import (
+	"github.com/lucas-clemente/quic-go/internal/protocol"
+	"github.com/lucas-clemente/quic-go/internal/utils"
+	"github.com/lucas-clemente/quic-go/internal/wire"
+
+	. "github.com/onsi/ginkgo"
+	. "github.com/onsi/gomega"
+)
+
+var _ = Describe("Server Session", func() {
+	var (
+		qsess *MockQuicSession
+		sess  *serverSession
+	)
+
+	BeforeEach(func() {
+		qsess = NewMockQuicSession(mockCtrl)
+		sess = newServerSession(qsess, &Config{}, utils.DefaultLogger).(*serverSession)
+	})
+
+	It("handles packets", func() {
+		p := &receivedPacket{
+			header: &wire.Header{DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5}},
+		}
+		qsess.EXPECT().handlePacket(p)
+		sess.handlePacket(p)
+	})
+
+	It("ignores Public Resets", func() {
+		p := &receivedPacket{
+			header: &wire.Header{
+				ResetFlag:        true,
+				DestConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef},
+			},
+		}
+		err := sess.handlePacketImpl(p)
+		Expect(err).To(MatchError("Received unexpected Public Reset for connection 0xdeadbeef"))
+	})
+
+	It("ignores delayed packets with mismatching versions, for gQUIC", func() {
+		qsess.EXPECT().GetVersion().Return(protocol.VersionNumber(100))
+		// don't EXPECT any calls to handlePacket()
+		p := &receivedPacket{
+			header: &wire.Header{
+				VersionFlag:      true,
+				Version:          protocol.VersionNumber(123),
+				DestConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef},
+			},
+		}
+		err := sess.handlePacketImpl(p)
+		Expect(err).ToNot(HaveOccurred())
+	})
+
+	It("ignores delayed packets with mismatching versions, for IETF QUIC", func() {
+		qsess.EXPECT().GetVersion().Return(protocol.VersionNumber(100))
+		// don't EXPECT any calls to handlePacket()
+		p := &receivedPacket{
+			header: &wire.Header{
+				IsLongHeader:     true,
+				Version:          protocol.VersionNumber(123),
+				DestConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef},
+			},
+		}
+		err := sess.handlePacketImpl(p)
+		Expect(err).ToNot(HaveOccurred())
+	})
+
+	It("ignores packets with the wrong Long Header type", func() {
+		qsess.EXPECT().GetVersion().Return(protocol.VersionNumber(100))
+		p := &receivedPacket{
+			header: &wire.Header{
+				IsLongHeader:     true,
+				Type:             protocol.PacketType0RTT,
+				Version:          protocol.VersionNumber(100),
+				DestConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef},
+			},
+		}
+		err := sess.handlePacketImpl(p)
+		Expect(err).To(MatchError("Received unsupported packet type: 0-RTT Protected"))
+	})
+
+	It("passes on Handshake packets", func() {
+		p := &receivedPacket{
+			header: &wire.Header{
+				IsLongHeader:     true,
+				Type:             protocol.PacketTypeHandshake,
+				Version:          protocol.VersionNumber(100),
+				DestConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef},
+			},
+		}
+		qsess.EXPECT().GetVersion().Return(protocol.VersionNumber(100))
+		qsess.EXPECT().handlePacket(p)
+		Expect(sess.handlePacketImpl(p)).To(Succeed())
+	})
+
+	It("has the right perspective", func() {
+		Expect(sess.GetPerspective()).To(Equal(protocol.PerspectiveServer))
+	})
+})

+ 61 - 233
server_test.go

@@ -14,7 +14,6 @@ import (
 	"github.com/lucas-clemente/quic-go/internal/testdata"
 	"github.com/lucas-clemente/quic-go/internal/utils"
 	"github.com/lucas-clemente/quic-go/internal/wire"
-	"github.com/lucas-clemente/quic-go/qerr"
 
 	. "github.com/onsi/ginkgo"
 	. "github.com/onsi/gomega"
@@ -27,6 +26,8 @@ type mockSession struct {
 	runner sessionRunner
 }
 
+func (s *mockSession) GetPerspective() protocol.Perspective { panic("not implemented") }
+
 var _ = Describe("Server", func() {
 	var (
 		conn    *mockPacketConn
@@ -89,7 +90,7 @@ var _ = Describe("Server", func() {
 	Context("with mock session", func() {
 		var (
 			serv           *server
-			firstPacket    []byte // a valid first packet for a new connection with connectionID 0x4cfa9f9b668619f6 (= connID)
+			firstPacket    *receivedPacket
 			connID         = protocol.ConnectionID{0x4c, 0xfa, 0x9f, 0x9b, 0x66, 0x86, 0x19, 0xf6}
 			sessions       = make([]*MockQuicSession, 0)
 			sessionHandler *MockPacketHandlerManager
@@ -126,9 +127,16 @@ var _ = Describe("Server", func() {
 			serv.setup()
 			b := &bytes.Buffer{}
 			utils.BigEndian.WriteUint32(b, uint32(protocol.SupportedVersions[0]))
-			firstPacket = []byte{0x09, 0x4c, 0xfa, 0x9f, 0x9b, 0x66, 0x86, 0x19, 0xf6}
-			firstPacket = append(append(firstPacket, b.Bytes()...), 0x01)
-			firstPacket = append(firstPacket, bytes.Repeat([]byte{0}, protocol.MinClientHelloSize)...) // add padding
+			firstPacket = &receivedPacket{
+				header: &wire.Header{
+					VersionFlag:      true,
+					Version:          serv.config.Versions[0],
+					DestConnectionID: protocol.ConnectionID{0x4c, 0xfa, 0x9f, 0x9b, 0x66, 0x86, 0x19, 0xf6},
+					PacketNumber:     1,
+				},
+				data:    bytes.Repeat([]byte{0}, protocol.MinClientHelloSize),
+				rcvTime: time.Now(),
+			}
 		})
 
 		AfterEach(func() {
@@ -150,12 +158,10 @@ var _ = Describe("Server", func() {
 			s.EXPECT().run().Do(func() { close(run) })
 			sessions = append(sessions, s)
 
-			sessionHandler.EXPECT().Get(connID)
-			sessionHandler.EXPECT().Add(connID, gomock.Any()).Do(func(_ protocol.ConnectionID, sess packetHandler) {
-				Expect(sess.(*mockSession).connID).To(Equal(connID))
+			sessionHandler.EXPECT().Add(connID, gomock.Any()).Do(func(cid protocol.ConnectionID, _ packetHandler) {
+				Expect(cid).To(Equal(connID))
 			})
-			err := serv.handlePacket(nil, firstPacket)
-			Expect(err).ToNot(HaveOccurred())
+			Expect(serv.handlePacketImpl(firstPacket)).To(Succeed())
 			Eventually(run).Should(BeClosed())
 		})
 
@@ -165,7 +171,8 @@ var _ = Describe("Server", func() {
 			err := serv.setupTLS()
 			Expect(err).ToNot(HaveOccurred())
 			added := make(chan struct{})
-			sessionHandler.EXPECT().Add(connID, sess).Do(func(protocol.ConnectionID, packetHandler) {
+			sessionHandler.EXPECT().Add(connID, gomock.Any()).Do(func(_ protocol.ConnectionID, ph packetHandler) {
+				Expect(ph.GetPerspective()).To(Equal(protocol.PerspectiveServer))
 				close(added)
 			})
 			serv.serverTLS.sessionChan <- tlsSession{
@@ -184,17 +191,15 @@ var _ = Describe("Server", func() {
 			done := make(chan struct{})
 			go func() {
 				defer GinkgoRecover()
-				sess, err := serv.Accept()
+				_, err := serv.Accept()
 				Expect(err).ToNot(HaveOccurred())
-				Expect(sess.(*mockSession).connID).To(Equal(connID))
 				close(done)
 			}()
-			sessionHandler.EXPECT().Get(connID)
 			sessionHandler.EXPECT().Add(connID, gomock.Any()).Do(func(_ protocol.ConnectionID, sess packetHandler) {
 				Consistently(done).ShouldNot(BeClosed())
-				sess.(*mockSession).runner.onHandshakeComplete(sess.(Session))
+				sess.(*serverSession).quicSession.(*mockSession).runner.onHandshakeComplete(sess.(Session))
 			})
-			err := serv.handlePacket(nil, firstPacket)
+			err := serv.handlePacketImpl(firstPacket)
 			Expect(err).ToNot(HaveOccurred())
 			Eventually(done).Should(BeClosed())
 			Eventually(run).Should(BeClosed())
@@ -212,45 +217,20 @@ var _ = Describe("Server", func() {
 				serv.Accept()
 				close(done)
 			}()
-			sessionHandler.EXPECT().Get(connID)
-			sessionHandler.EXPECT().Add(connID, gomock.Any()).Do(func(_ protocol.ConnectionID, sess packetHandler) {
+			sessionHandler.EXPECT().Add(connID, gomock.Any()).Do(func(protocol.ConnectionID, packetHandler) {
 				run <- errors.New("handshake error")
 			})
-			err := serv.handlePacket(nil, firstPacket)
-			Expect(err).ToNot(HaveOccurred())
+			Expect(serv.handlePacketImpl(firstPacket)).To(Succeed())
 			Consistently(done).ShouldNot(BeClosed())
+
 			// make the go routine return
-			sessionHandler.EXPECT().Close()
 			close(serv.errorChan)
-			serv.Close()
 			Eventually(done).Should(BeClosed())
 		})
 
-		It("assigns packets to existing sessions", func() {
-			sess := NewMockQuicSession(mockCtrl)
-			sess.EXPECT().handlePacket(gomock.Any())
-			sess.EXPECT().GetVersion()
-
-			sessionHandler.EXPECT().Get(connID).Return(sess, true)
-			err := serv.handlePacket(nil, []byte{0x08, 0x4c, 0xfa, 0x9f, 0x9b, 0x66, 0x86, 0x19, 0xf6, 0x01})
-			Expect(err).ToNot(HaveOccurred())
-		})
-
-		It("closes the sessionHandler and the connection when Close is called", func() {
-			go func() {
-				defer GinkgoRecover()
-				serv.serve()
-			}()
-			// close the server
-			sessionHandler.EXPECT().Close().AnyTimes()
+		It("closes the sessionHandler when Close is called", func() {
+			sessionHandler.EXPECT().CloseServer()
 			Expect(serv.Close()).To(Succeed())
-			Expect(conn.closed).To(BeTrue())
-		})
-
-		It("ignores packets for closed sessions", func() {
-			sessionHandler.EXPECT().Get(connID).Return(nil, true)
-			err := serv.handlePacket(nil, firstPacket)
-			Expect(err).ToNot(HaveOccurred())
 		})
 
 		It("works if no quic.Config is given", func(done Done) {
@@ -264,163 +244,56 @@ var _ = Describe("Server", func() {
 			ln, err := ListenAddr("127.0.0.1:0", testdata.GetTLSConfig(), config)
 			Expect(err).ToNot(HaveOccurred())
 
-			var returned bool
+			done := make(chan struct{})
 			go func() {
 				defer GinkgoRecover()
-				_, err := ln.Accept()
-				Expect(err).To(HaveOccurred())
-				Expect(err.Error()).To(ContainSubstring("use of closed network connection"))
-				returned = true
+				ln.Accept()
+				close(done)
 			}()
 			ln.Close()
-			Eventually(func() bool { return returned }).Should(BeTrue())
+			Eventually(done).Should(BeClosed())
 		})
 
-		It("errors when encountering a connection error", func() {
-			testErr := errors.New("connection error")
-			conn.readErr = testErr
-			sessionHandler.EXPECT().Close()
+		It("returns Accept when it is closed", func() {
 			done := make(chan struct{})
 			go func() {
 				defer GinkgoRecover()
-				serv.serve()
+				_, err := serv.Accept()
+				Expect(err).To(MatchError("server closed"))
 				close(done)
 			}()
-			_, err := serv.Accept()
-			Expect(err).To(MatchError(testErr))
+			sessionHandler.EXPECT().CloseServer()
+			Expect(serv.Close()).To(Succeed())
 			Eventually(done).Should(BeClosed())
 		})
 
-		It("ignores delayed packets with mismatching versions", func() {
-			sess := NewMockQuicSession(mockCtrl)
-			sess.EXPECT().GetVersion()
-			// don't EXPECT any handlePacket() calls to this session
-			sessionHandler.EXPECT().Get(connID).Return(sess, true)
-
-			b := &bytes.Buffer{}
-			// add an unsupported version
-			data := []byte{0x09, 0x4c, 0xfa, 0x9f, 0x9b, 0x66, 0x86, 0x19, 0xf6}
-			utils.BigEndian.WriteUint32(b, uint32(protocol.SupportedVersions[0]+1))
-			data = append(append(data, b.Bytes()...), 0x01)
-			err := serv.handlePacket(nil, data)
-			Expect(err).ToNot(HaveOccurred())
-			// if we didn't ignore the packet, the server would try to send a version negotiation packet, which would make the test panic because it doesn't have a udpConn
-			Expect(conn.dataWritten.Bytes()).To(BeEmpty())
-		})
-
-		It("errors on invalid public header", func() {
-			err := serv.handlePacket(nil, nil)
-			Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.InvalidPacketHeader))
-		})
-
-		It("errors on packets that are smaller than the Payload Length in the packet header", func() {
-			sess := NewMockQuicSession(mockCtrl)
-			sess.EXPECT().GetVersion().Return(protocol.VersionTLS)
-			sessionHandler.EXPECT().Get(connID).Return(sess, true)
-
-			serv.supportsTLS = true
-			b := &bytes.Buffer{}
-			hdr := &wire.Header{
-				IsLongHeader:     true,
-				Type:             protocol.PacketTypeHandshake,
-				PayloadLen:       1000,
-				SrcConnectionID:  connID,
-				DestConnectionID: connID,
-				PacketNumberLen:  protocol.PacketNumberLen1,
-				Version:          versionIETFFrames,
-			}
-			Expect(hdr.Write(b, protocol.PerspectiveClient, versionIETFFrames)).To(Succeed())
-			err := serv.handlePacket(nil, append(b.Bytes(), make([]byte, 456)...))
-			Expect(err).To(MatchError("packet payload (456 bytes) is smaller than the expected payload length (1000 bytes)"))
-		})
-
-		It("cuts packets at the payload length", func() {
-			sess := NewMockQuicSession(mockCtrl)
-			sess.EXPECT().handlePacket(gomock.Any()).Do(func(packet *receivedPacket) {
-				Expect(packet.data).To(HaveLen(123))
-			})
-			sess.EXPECT().GetVersion().Return(protocol.VersionTLS)
-			sessionHandler.EXPECT().Get(connID).Return(sess, true)
-
-			serv.supportsTLS = true
-			b := &bytes.Buffer{}
-			hdr := &wire.Header{
-				IsLongHeader:     true,
-				Type:             protocol.PacketTypeHandshake,
-				PayloadLen:       123,
-				SrcConnectionID:  connID,
-				DestConnectionID: connID,
-				PacketNumberLen:  protocol.PacketNumberLen1,
-				Version:          versionIETFFrames,
-			}
-			Expect(hdr.Write(b, protocol.PerspectiveClient, versionIETFFrames)).To(Succeed())
-			err := serv.handlePacket(nil, append(b.Bytes(), make([]byte, 456)...))
-			Expect(err).ToNot(HaveOccurred())
-		})
-
-		It("drops packets with invalid packet types", func() {
-			sess := NewMockQuicSession(mockCtrl)
-			sess.EXPECT().GetVersion().Return(protocol.VersionTLS)
-			sessionHandler.EXPECT().Get(connID).Return(sess, true)
-
-			serv.supportsTLS = true
-			b := &bytes.Buffer{}
-			hdr := &wire.Header{
-				IsLongHeader:     true,
-				Type:             protocol.PacketTypeRetry,
-				PayloadLen:       123,
-				SrcConnectionID:  connID,
-				DestConnectionID: connID,
-				PacketNumberLen:  protocol.PacketNumberLen1,
-				Version:          versionIETFFrames,
-			}
-			Expect(hdr.Write(b, protocol.PerspectiveClient, versionIETFFrames)).To(Succeed())
-			err := serv.handlePacket(nil, append(b.Bytes(), make([]byte, 456)...))
-			Expect(err).To(MatchError("Received unsupported packet type: Retry"))
-		})
-
-		It("ignores Public Resets", func() {
-			sess := NewMockQuicSession(mockCtrl)
-			sess.EXPECT().GetVersion().Return(protocol.VersionTLS)
-			sessionHandler.EXPECT().Get(connID).Return(sess, true)
-
-			err := serv.handlePacket(nil, wire.WritePublicReset(connID, 1, 1337))
-			Expect(err).ToNot(HaveOccurred())
+		It("returns Accept with the right error when closeWithError is called", func() {
+			testErr := errors.New("connection error")
+			done := make(chan struct{})
+			go func() {
+				defer GinkgoRecover()
+				_, err := serv.Accept()
+				Expect(err).To(MatchError(testErr))
+				close(done)
+			}()
+			sessionHandler.EXPECT().CloseServer()
+			serv.closeWithError(testErr)
+			Eventually(done).Should(BeClosed())
 		})
 
 		It("doesn't try to process a packet after sending a gQUIC Version Negotiation Packet", func() {
 			config.Versions = []protocol.VersionNumber{99}
-			b := &bytes.Buffer{}
-			hdr := wire.Header{
-				VersionFlag:      true,
-				DestConnectionID: connID,
-				PacketNumber:     1,
-				PacketNumberLen:  protocol.PacketNumberLen2,
+			p := &receivedPacket{
+				header: &wire.Header{
+					VersionFlag:      true,
+					DestConnectionID: connID,
+					PacketNumber:     1,
+					PacketNumberLen:  protocol.PacketNumberLen2,
+				},
+				data: make([]byte, protocol.MinClientHelloSize),
 			}
-			Expect(hdr.Write(b, protocol.PerspectiveClient, 13 /* not a valid QUIC version */)).To(Succeed())
-			b.Write(bytes.Repeat([]byte{0}, protocol.MinClientHelloSize)) // add a fake CHLO
-			serv.conn = conn
-			sessionHandler.EXPECT().Get(connID)
-			err := serv.handlePacket(nil, b.Bytes())
+			Expect(serv.handlePacketImpl(p)).To(Succeed())
 			Expect(conn.dataWritten.Bytes()).ToNot(BeEmpty())
-			Expect(err).ToNot(HaveOccurred())
-		})
-
-		It("doesn't respond with a version negotiation packet if the first packet is too small", func() {
-			b := &bytes.Buffer{}
-			hdr := wire.Header{
-				VersionFlag:      true,
-				DestConnectionID: connID,
-				PacketNumber:     1,
-				PacketNumberLen:  protocol.PacketNumberLen2,
-			}
-			Expect(hdr.Write(b, protocol.PerspectiveClient, 13 /* not a valid QUIC version */)).To(Succeed())
-			b.Write(bytes.Repeat([]byte{0}, protocol.MinClientHelloSize-1)) // this packet is 1 byte too small
-			serv.conn = conn
-			sessionHandler.EXPECT().Get(connID)
-			err := serv.handlePacket(udpAddr, b.Bytes())
-			Expect(err).To(MatchError("dropping small packet with unknown version"))
-			Expect(conn.dataWritten.Len()).Should(BeZero())
 		})
 	})
 
@@ -523,8 +396,11 @@ var _ = Describe("Server", func() {
 	})
 
 	It("sends an IETF draft style Version Negotaion Packet, if the client sent a IETF draft style header", func() {
-		connID := protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1}
 		config.Versions = append(config.Versions, protocol.VersionTLS)
+		ln, err := Listen(conn, testdata.GetTLSConfig(), config)
+		Expect(err).ToNot(HaveOccurred())
+
+		connID := protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1}
 		b := &bytes.Buffer{}
 		hdr := wire.Header{
 			Type:             protocol.PacketTypeInitial,
@@ -536,13 +412,10 @@ var _ = Describe("Server", func() {
 			Version:          0x1234,
 			PayloadLen:       protocol.MinInitialPacketSize,
 		}
-		err := hdr.Write(b, protocol.PerspectiveClient, protocol.VersionTLS)
-		Expect(err).ToNot(HaveOccurred())
+		Expect(hdr.Write(b, protocol.PerspectiveClient, protocol.VersionTLS)).To(Succeed())
 		b.Write(bytes.Repeat([]byte{0}, protocol.MinInitialPacketSize)) // add a fake CHLO
 		conn.dataToRead <- b.Bytes()
 		conn.dataReadFrom = udpAddr
-		ln, err := Listen(conn, testdata.GetTLSConfig(), config)
-		Expect(err).ToNot(HaveOccurred())
 
 		done := make(chan struct{})
 		go func() {
@@ -568,51 +441,6 @@ var _ = Describe("Server", func() {
 		Eventually(done).Should(BeClosed())
 	})
 
-	It("ignores IETF draft style Initial packets, if it doesn't support TLS", func() {
-		connID := protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1}
-		b := &bytes.Buffer{}
-		hdr := wire.Header{
-			Type:             protocol.PacketTypeInitial,
-			IsLongHeader:     true,
-			DestConnectionID: connID,
-			SrcConnectionID:  connID,
-			PacketNumber:     0x55,
-			PacketNumberLen:  protocol.PacketNumberLen1,
-			Version:          protocol.VersionTLS,
-		}
-		err := hdr.Write(b, protocol.PerspectiveClient, protocol.VersionTLS)
-		Expect(err).ToNot(HaveOccurred())
-		b.Write(bytes.Repeat([]byte{0}, protocol.MinClientHelloSize)) // add a fake CHLO
-		conn.dataToRead <- b.Bytes()
-		conn.dataReadFrom = udpAddr
-		ln, err := Listen(conn, testdata.GetTLSConfig(), config)
-		Expect(err).ToNot(HaveOccurred())
-		defer ln.Close()
-		Consistently(func() int { return conn.dataWritten.Len() }).Should(BeZero())
-	})
-
-	It("ignores non-Initial Long Header packets for unknown connections", func() {
-		connID := protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1}
-		b := &bytes.Buffer{}
-		hdr := wire.Header{
-			Type:             protocol.PacketTypeHandshake,
-			IsLongHeader:     true,
-			DestConnectionID: connID,
-			SrcConnectionID:  connID,
-			PacketNumber:     0x55,
-			PacketNumberLen:  protocol.PacketNumberLen1,
-			Version:          protocol.VersionTLS,
-		}
-		err := hdr.Write(b, protocol.PerspectiveClient, protocol.VersionTLS)
-		Expect(err).ToNot(HaveOccurred())
-		conn.dataToRead <- b.Bytes()
-		conn.dataReadFrom = udpAddr
-		ln, err := Listen(conn, testdata.GetTLSConfig(), config)
-		Expect(err).ToNot(HaveOccurred())
-		defer ln.Close()
-		Consistently(func() int { return conn.dataWritten.Len() }).Should(BeZero())
-	})
-
 	It("sends a PublicReset for new connections that don't have the VersionFlag set", func() {
 		conn.dataReadFrom = udpAddr
 		conn.dataToRead <- []byte{0x08, 0x4c, 0xfa, 0x9f, 0x9b, 0x66, 0x86, 0x19, 0xf6, 0x01}

+ 3 - 3
server_tls.go

@@ -17,7 +17,7 @@ import (
 
 type tlsSession struct {
 	connID protocol.ConnectionID
-	sess   packetHandler
+	sess   quicSession
 }
 
 type serverTLS struct {
@@ -126,7 +126,7 @@ func (s *serverTLS) sendConnectionClose(remoteAddr net.Addr, clientHdr *wire.Hea
 	return err
 }
 
-func (s *serverTLS) handleInitialImpl(remoteAddr net.Addr, hdr *wire.Header, data []byte) (packetHandler, protocol.ConnectionID, error) {
+func (s *serverTLS) handleInitialImpl(remoteAddr net.Addr, hdr *wire.Header, data []byte) (quicSession, protocol.ConnectionID, error) {
 	if hdr.DestConnectionID.Len() < protocol.MinConnectionIDLenInitial {
 		return nil, nil, errors.New("dropping Initial packet with too short connection ID")
 	}
@@ -164,7 +164,7 @@ func (s *serverTLS) handleInitialImpl(remoteAddr net.Addr, hdr *wire.Header, dat
 	return sess, connID, nil
 }
 
-func (s *serverTLS) handleUnpackedInitial(remoteAddr net.Addr, hdr *wire.Header, frame *wire.StreamFrame, aead crypto.AEAD) (packetHandler, protocol.ConnectionID, error) {
+func (s *serverTLS) handleUnpackedInitial(remoteAddr net.Addr, hdr *wire.Header, frame *wire.StreamFrame, aead crypto.AEAD) (quicSession, protocol.ConnectionID, error) {
 	version := hdr.Version
 	bc := handshake.NewCryptoStreamConn(remoteAddr)
 	bc.AddDataForReading(frame.Data)