Browse Source

when the encryption level changes, reject data on that crypto stream

There are two checks that need to be performed:
1. the crypto stream must not have any more data queued for reading
2. when receiving CRYPTO frames for that crypto stream afterwards, they
must not exceed the highest offset received on that stream
Marten Seemann 6 months ago
parent
commit
387c28d707

+ 25 - 1
crypto_stream.go

@@ -1,6 +1,7 @@
 package quic
 
 import (
+	"errors"
 	"fmt"
 	"io"
 
@@ -13,6 +14,7 @@ type cryptoStream interface {
 	// for receiving data
 	HandleCryptoFrame(*wire.CryptoFrame) error
 	GetCryptoData() []byte
+	Finish() error
 	// for sending data
 	io.Writer
 	HasData() bool
@@ -23,6 +25,9 @@ type cryptoStreamImpl struct {
 	queue  *frameSorter
 	msgBuf []byte
 
+	highestOffset protocol.ByteCount
+	finished      bool
+
 	writeOffset protocol.ByteCount
 	writeBuf    []byte
 }
@@ -34,9 +39,20 @@ func newCryptoStream() cryptoStream {
 }
 
 func (s *cryptoStreamImpl) HandleCryptoFrame(f *wire.CryptoFrame) error {
-	if maxOffset := f.Offset + protocol.ByteCount(len(f.Data)); maxOffset > protocol.MaxCryptoStreamOffset {
+	highestOffset := f.Offset + protocol.ByteCount(len(f.Data))
+	if maxOffset := highestOffset; maxOffset > protocol.MaxCryptoStreamOffset {
 		return fmt.Errorf("received invalid offset %d on crypto stream, maximum allowed %d", maxOffset, protocol.MaxCryptoStreamOffset)
 	}
+	if s.finished {
+		if highestOffset > s.highestOffset {
+			// reject crypto data received after this stream was already finished
+			return errors.New("received crypto data after change of encryption level")
+		}
+		// ignore data with a smaller offset than the highest received
+		// could e.g. be a retransmission
+		return nil
+	}
+	s.highestOffset = utils.MaxByteCount(s.highestOffset, highestOffset)
 	if err := s.queue.Push(f.Data, f.Offset, false); err != nil {
 		return err
 	}
@@ -64,6 +80,14 @@ func (s *cryptoStreamImpl) GetCryptoData() []byte {
 	return msg
 }
 
+func (s *cryptoStreamImpl) Finish() error {
+	if s.queue.HasMoreData() {
+		return errors.New("encryption level changed, but crypto stream has more data to read")
+	}
+	s.finished = true
+	return nil
+}
+
 // Writes writes data that should be sent out in CRYPTO frames
 func (s *cryptoStreamImpl) Write(p []byte) (int, error) {
 	s.writeBuf = append(s.writeBuf, p...)

+ 4 - 2
crypto_stream_manager.go

@@ -8,7 +8,7 @@ import (
 )
 
 type cryptoDataHandler interface {
-	HandleMessage([]byte, protocol.EncryptionLevel)
+	HandleMessage([]byte, protocol.EncryptionLevel) bool
 }
 
 type cryptoStreamManager struct {
@@ -48,6 +48,8 @@ func (m *cryptoStreamManager) HandleCryptoFrame(frame *wire.CryptoFrame, encLeve
 		if data == nil {
 			return nil
 		}
-		m.cryptoHandler.HandleMessage(data, encLevel)
+		if encLevelFinished := m.cryptoHandler.HandleMessage(data, encLevel); encLevelFinished {
+			return str.Finish()
+		}
 	}
 }

+ 26 - 0
crypto_stream_manager_test.go

@@ -1,6 +1,9 @@
 package quic
 
 import (
+	"errors"
+
+	"github.com/golang/mock/gomock"
 	"github.com/lucas-clemente/quic-go/internal/protocol"
 	"github.com/lucas-clemente/quic-go/internal/wire"
 
@@ -61,6 +64,29 @@ var _ = Describe("Crypto Stream Manager", func() {
 		Expect(csm.HandleCryptoFrame(cf, protocol.EncryptionHandshake)).To(Succeed())
 	})
 
+	It("finishes the crypto stream, when the crypto setup is done with this encryption level", func() {
+		cf := &wire.CryptoFrame{Data: []byte("foobar")}
+		gomock.InOrder(
+			handshakeStream.EXPECT().HandleCryptoFrame(cf),
+			handshakeStream.EXPECT().GetCryptoData().Return([]byte("foobar")),
+			cs.EXPECT().HandleMessage([]byte("foobar"), protocol.EncryptionHandshake).Return(true),
+			handshakeStream.EXPECT().Finish(),
+		)
+		Expect(csm.HandleCryptoFrame(cf, protocol.EncryptionHandshake)).To(Succeed())
+	})
+
+	It("returns errors that occur when finishing a stream", func() {
+		testErr := errors.New("test error")
+		cf := &wire.CryptoFrame{Data: []byte("foobar")}
+		gomock.InOrder(
+			handshakeStream.EXPECT().HandleCryptoFrame(cf),
+			handshakeStream.EXPECT().GetCryptoData().Return([]byte("foobar")),
+			cs.EXPECT().HandleMessage([]byte("foobar"), protocol.EncryptionHandshake).Return(true),
+			handshakeStream.EXPECT().Finish().Return(testErr),
+		)
+		Expect(csm.HandleCryptoFrame(cf, protocol.EncryptionHandshake)).To(MatchError(testErr))
+	})
+
 	It("errors for unknown encryption levels", func() {
 		err := csm.HandleCryptoFrame(&wire.CryptoFrame{}, protocol.Encryption1RTT)
 		Expect(err).To(MatchError("received CRYPTO frame with unexpected encryption level: 1-RTT"))

+ 46 - 0
crypto_stream_test.go

@@ -89,6 +89,52 @@ var _ = Describe("Crypto Stream", func() {
 			Expect(str.GetCryptoData()).To(Equal(msg))
 			Expect(str.GetCryptoData()).To(BeNil())
 		})
+
+		Context("finishing", func() {
+			It("errors if there's still data to read after finishing", func() {
+				Expect(str.HandleCryptoFrame(&wire.CryptoFrame{
+					Data:   createHandshakeMessage(5),
+					Offset: 10,
+				})).To(Succeed())
+				err := str.Finish()
+				Expect(err).To(MatchError("encryption level changed, but crypto stream has more data to read"))
+			})
+
+			It("works with reordered data", func() {
+				f1 := &wire.CryptoFrame{
+					Data: []byte("foo"),
+				}
+				f2 := &wire.CryptoFrame{
+					Offset: 3,
+					Data:   []byte("bar"),
+				}
+				Expect(str.HandleCryptoFrame(f2)).To(Succeed())
+				Expect(str.HandleCryptoFrame(f1)).To(Succeed())
+				Expect(str.Finish()).To(Succeed())
+				Expect(str.HandleCryptoFrame(f2)).To(Succeed())
+			})
+
+			It("rejects new crypto data after finishing", func() {
+				Expect(str.Finish()).To(Succeed())
+				err := str.HandleCryptoFrame(&wire.CryptoFrame{
+					Data: createHandshakeMessage(5),
+				})
+				Expect(err).To(MatchError("received crypto data after change of encryption level"))
+			})
+
+			It("ignores crypto data below the maximum offset received before finishing", func() {
+				msg := createHandshakeMessage(15)
+				Expect(str.HandleCryptoFrame(&wire.CryptoFrame{
+					Data: msg,
+				})).To(Succeed())
+				Expect(str.GetCryptoData()).To(Equal(msg))
+				Expect(str.Finish()).To(Succeed())
+				Expect(str.HandleCryptoFrame(&wire.CryptoFrame{
+					Offset: protocol.ByteCount(len(msg) - 6),
+					Data:   []byte("foobar"),
+				})).To(Succeed())
+			})
+		})
 	})
 
 	Context("writing data", func() {

+ 5 - 0
frame_sorter.go

@@ -156,3 +156,8 @@ func (s *frameSorter) Pop() ([]byte /* data */, bool /* fin */) {
 	s.readPos += protocol.ByteCount(len(data))
 	return data, s.readPos >= s.finalOffset
 }
+
+// HasMoreData says if there is any more data queued at *any* offset.
+func (s *frameSorter) HasMoreData() bool {
+	return len(s.queue) > 0
+}

+ 9 - 0
frame_sorter_test.go

@@ -55,6 +55,15 @@ var _ = Describe("STREAM frame sorter", func() {
 			Expect(s.Pop()).To(BeNil())
 		})
 
+		It("says if has more data", func() {
+			Expect(s.HasMoreData()).To(BeFalse())
+			Expect(s.Push([]byte("foo"), 0, false)).To(Succeed())
+			Expect(s.HasMoreData()).To(BeTrue())
+			data, _ := s.Pop()
+			Expect(data).To(Equal([]byte("foo")))
+			Expect(s.HasMoreData()).To(BeFalse())
+		})
+
 		Context("FIN handling", func() {
 			It("saves a FIN at offset 0", func() {
 				Expect(s.Push(nil, 0, true)).To(Succeed())

+ 24 - 20
internal/handshake/crypto_setup_tls.go

@@ -271,19 +271,20 @@ func (h *cryptoSetupTLS) RunHandshake() error {
 
 // handleMessage handles a TLS handshake message.
 // It is called by the crypto streams when a new message is available.
-func (h *cryptoSetupTLS) HandleMessage(data []byte, encLevel protocol.EncryptionLevel) {
+// It returns if it is done with messages on the same encryption level.
+func (h *cryptoSetupTLS) HandleMessage(data []byte, encLevel protocol.EncryptionLevel) bool /* stream finished */ {
 	msgType := messageType(data[0])
 	h.logger.Debugf("Received %s message (%d bytes, encryption level: %s)", msgType, len(data), encLevel)
 	if err := h.checkEncryptionLevel(msgType, encLevel); err != nil {
 		h.messageErrChan <- err
-		return
+		return false
 	}
 	h.messageChan <- data
 	switch h.perspective {
 	case protocol.PerspectiveClient:
-		h.handleMessageForClient(msgType)
+		return h.handleMessageForClient(msgType)
 	case protocol.PerspectiveServer:
-		h.handleMessageForServer(msgType)
+		return h.handleMessageForServer(msgType)
 	default:
 		panic("")
 	}
@@ -310,78 +311,81 @@ func (h *cryptoSetupTLS) checkEncryptionLevel(msgType messageType, encLevel prot
 	return nil
 }
 
-func (h *cryptoSetupTLS) handleMessageForServer(msgType messageType) {
+func (h *cryptoSetupTLS) handleMessageForServer(msgType messageType) bool {
 	switch msgType {
 	case typeClientHello:
 		select {
 		case params := <-h.receivedTransportParams:
 			h.handleParamsCallback(&params)
 		case <-h.handshakeErrChan:
-			return
+			return false
 		}
 		// get the handshake write key
 		select {
 		case <-h.receivedWriteKey:
 		case <-h.handshakeErrChan:
-			return
+			return false
 		}
 		// get the 1-RTT write key
 		select {
 		case <-h.receivedWriteKey:
 		case <-h.handshakeErrChan:
-			return
+			return false
 		}
 		// get the handshake read key
 		// TODO: check that the initial stream doesn't have any more data
 		select {
 		case <-h.receivedReadKey:
 		case <-h.handshakeErrChan:
-			return
+			return false
 		}
 		h.handshakeEvent <- struct{}{}
+		return true
 	case typeCertificate, typeCertificateVerify:
 		// nothing to do
+		return false
 	case typeFinished:
 		// get the 1-RTT read key
-		// TODO: check that the handshake stream doesn't have any more data
 		select {
 		case <-h.receivedReadKey:
 		case <-h.handshakeErrChan:
-			return
+			return false
 		}
 		h.handshakeEvent <- struct{}{}
+		return true
 	default:
 		panic("unexpected handshake message")
 	}
 }
 
-func (h *cryptoSetupTLS) handleMessageForClient(msgType messageType) {
+func (h *cryptoSetupTLS) handleMessageForClient(msgType messageType) bool {
 	switch msgType {
 	case typeServerHello:
 		// get the handshake read key
-		// TODO: check that the initial stream doesn't have any more data
 		select {
 		case <-h.receivedReadKey:
 		case <-h.handshakeErrChan:
-			return
+			return false
 		}
 		h.handshakeEvent <- struct{}{}
+		return true
 	case typeEncryptedExtensions:
 		select {
 		case params := <-h.receivedTransportParams:
 			h.handleParamsCallback(&params)
 		case <-h.handshakeErrChan:
-			return
+			return false
 		}
+		return false
 	case typeCertificateRequest, typeCertificate, typeCertificateVerify:
 		// nothing to do
+		return false
 	case typeFinished:
 		// get the handshake write key
-		// TODO: check that the initial stream doesn't have any more data
 		select {
 		case <-h.receivedWriteKey:
 		case <-h.handshakeErrChan:
-			return
+			return false
 		}
 		// While the order of these two is not defined by the TLS spec,
 		// we have to do it on the same order as our TLS library does it.
@@ -389,16 +393,16 @@ func (h *cryptoSetupTLS) handleMessageForClient(msgType messageType) {
 		select {
 		case <-h.receivedWriteKey:
 		case <-h.handshakeErrChan:
-			return
+			return false
 		}
 		// get the 1-RTT read key
 		select {
 		case <-h.receivedReadKey:
 		case <-h.handshakeErrChan:
-			return
+			return false
 		}
-		// TODO: check that the handshake stream doesn't have any more data
 		h.handshakeEvent <- struct{}{}
+		return true
 	default:
 		panic("unexpected handshake message: ")
 	}

+ 1 - 1
internal/handshake/interface.go

@@ -44,7 +44,7 @@ type CryptoSetup interface {
 type CryptoSetupTLS interface {
 	baseCryptoSetup
 
-	HandleMessage([]byte, protocol.EncryptionLevel)
+	HandleMessage([]byte, protocol.EncryptionLevel) bool
 	OpenInitial(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error)
 	OpenHandshake(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error)
 	Open1RTT(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error)

+ 4 - 2
mock_crypto_data_handler.go

@@ -35,8 +35,10 @@ func (m *MockCryptoDataHandler) EXPECT() *MockCryptoDataHandlerMockRecorder {
 }
 
 // HandleMessage mocks base method
-func (m *MockCryptoDataHandler) HandleMessage(arg0 []byte, arg1 protocol.EncryptionLevel) {
-	m.ctrl.Call(m, "HandleMessage", arg0, arg1)
+func (m *MockCryptoDataHandler) HandleMessage(arg0 []byte, arg1 protocol.EncryptionLevel) bool {
+	ret := m.ctrl.Call(m, "HandleMessage", arg0, arg1)
+	ret0, _ := ret[0].(bool)
+	return ret0
 }
 
 // HandleMessage indicates an expected call of HandleMessage

+ 12 - 0
mock_crypto_stream_test.go

@@ -35,6 +35,18 @@ func (m *MockCryptoStream) EXPECT() *MockCryptoStreamMockRecorder {
 	return m.recorder
 }
 
+// Finish mocks base method
+func (m *MockCryptoStream) Finish() error {
+	ret := m.ctrl.Call(m, "Finish")
+	ret0, _ := ret[0].(error)
+	return ret0
+}
+
+// Finish indicates an expected call of Finish
+func (mr *MockCryptoStreamMockRecorder) Finish() *gomock.Call {
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Finish", reflect.TypeOf((*MockCryptoStream)(nil).Finish))
+}
+
 // GetCryptoData mocks base method
 func (m *MockCryptoStream) GetCryptoData() []byte {
 	ret := m.ctrl.Call(m, "GetCryptoData")