Browse Source

move parsing of the extended header to the unpacker

Marten Seemann 4 months ago
parent
commit
aaea375fb6
5 changed files with 165 additions and 117 deletions
  1. 1 1
      mock_unpacker_test.go
  2. 28 8
      packet_unpacker.go
  3. 129 24
      packet_unpacker_test.go
  4. 7 28
      session.go
  5. 0 56
      session_test.go

+ 1 - 1
mock_unpacker_test.go

@@ -35,7 +35,7 @@ func (m *MockUnpacker) EXPECT() *MockUnpackerMockRecorder {
 }
 
 // Unpack mocks base method
-func (m *MockUnpacker) Unpack(arg0 *wire.ExtendedHeader, arg1 []byte) (*unpackedPacket, error) {
+func (m *MockUnpacker) Unpack(arg0 *wire.Header, arg1 []byte) (*unpackedPacket, error) {
 	ret := m.ctrl.Call(m, "Unpack", arg0, arg1)
 	ret0, _ := ret[0].(*unpackedPacket)
 	ret1, _ := ret[1].(error)

+ 28 - 8
packet_unpacker.go

@@ -11,7 +11,8 @@ import (
 )
 
 type unpackedPacket struct {
-	packetNumber    protocol.PacketNumber
+	packetNumber    protocol.PacketNumber // the decoded packet number
+	hdr             *wire.ExtendedHeader
 	encryptionLevel protocol.EncryptionLevel
 	frames          []wire.Frame
 }
@@ -40,11 +41,30 @@ func newPacketUnpacker(aead quicAEAD, version protocol.VersionNumber) unpacker {
 	}
 }
 
-func (u *packetUnpacker) Unpack(hdr *wire.ExtendedHeader, data []byte) (*unpackedPacket, error) {
+func (u *packetUnpacker) Unpack(hdr *wire.Header, data []byte) (*unpackedPacket, error) {
+	r := bytes.NewReader(data)
+	extHdr, err := hdr.ParseExtended(r, u.version)
+	if err != nil {
+		return nil, fmt.Errorf("error parsing extended header: %s", err)
+	}
+	extHdr.Raw = data[:len(data)-r.Len()]
+	data = data[len(data)-r.Len():]
+
+	if hdr.IsLongHeader {
+		if hdr.Length < protocol.ByteCount(extHdr.PacketNumberLen) {
+			return nil, fmt.Errorf("packet length (%d bytes) shorter than packet number (%d bytes)", extHdr.Length, extHdr.PacketNumberLen)
+		}
+		if protocol.ByteCount(len(data))+protocol.ByteCount(extHdr.PacketNumberLen) < extHdr.Length {
+			return nil, fmt.Errorf("packet length (%d bytes) is smaller than the expected length (%d bytes)", len(data)+int(extHdr.PacketNumberLen), extHdr.Length)
+		}
+		data = data[:int(extHdr.Length)-int(extHdr.PacketNumberLen)]
+		// TODO(#1312): implement parsing of compound packets
+	}
+
 	pn := protocol.DecodePacketNumber(
-		hdr.PacketNumberLen,
+		extHdr.PacketNumberLen,
 		u.largestRcvdPacketNumber,
-		hdr.PacketNumber,
+		extHdr.PacketNumber,
 	)
 
 	buf := *getPacketBuffer()
@@ -53,19 +73,18 @@ func (u *packetUnpacker) Unpack(hdr *wire.ExtendedHeader, data []byte) (*unpacke
 
 	var decrypted []byte
 	var encryptionLevel protocol.EncryptionLevel
-	var err error
 	switch hdr.Type {
 	case protocol.PacketTypeInitial:
-		decrypted, err = u.aead.OpenInitial(buf, data, pn, hdr.Raw)
+		decrypted, err = u.aead.OpenInitial(buf, data, pn, extHdr.Raw)
 		encryptionLevel = protocol.EncryptionInitial
 	case protocol.PacketTypeHandshake:
-		decrypted, err = u.aead.OpenHandshake(buf, data, pn, hdr.Raw)
+		decrypted, err = u.aead.OpenHandshake(buf, data, pn, extHdr.Raw)
 		encryptionLevel = protocol.EncryptionHandshake
 	default:
 		if hdr.IsLongHeader {
 			return nil, fmt.Errorf("unknown packet type: %s", hdr.Type)
 		}
-		decrypted, err = u.aead.Open1RTT(buf, data, pn, hdr.Raw)
+		decrypted, err = u.aead.Open1RTT(buf, data, pn, extHdr.Raw)
 		encryptionLevel = protocol.Encryption1RTT
 	}
 	if err != nil {
@@ -81,6 +100,7 @@ func (u *packetUnpacker) Unpack(hdr *wire.ExtendedHeader, data []byte) (*unpacke
 	}
 
 	return &unpackedPacket{
+		hdr:             extHdr,
 		packetNumber:    pn,
 		encryptionLevel: encryptionLevel,
 		frames:          fs,

+ 129 - 24
packet_unpacker_test.go

@@ -14,82 +14,187 @@ import (
 )
 
 var _ = Describe("Packet Unpacker", func() {
+	const version = protocol.VersionTLS
 	var (
 		unpacker *packetUnpacker
-		hdr      *wire.ExtendedHeader
 		aead     *MockQuicAEAD
+		connID   = protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}
 	)
 
+	getHeader := func(extHdr *wire.ExtendedHeader) (*wire.Header, []byte) {
+		buf := &bytes.Buffer{}
+		Expect(extHdr.Write(buf, protocol.VersionWhatever)).To(Succeed())
+		hdr, err := wire.ParseHeader(bytes.NewReader(buf.Bytes()), connID.Len())
+		Expect(err).ToNot(HaveOccurred())
+		return hdr, buf.Bytes()
+	}
+
 	BeforeEach(func() {
 		aead = NewMockQuicAEAD(mockCtrl)
-		hdr = &wire.ExtendedHeader{
-			PacketNumber:    10,
-			PacketNumberLen: 1,
-			Raw:             []byte{0x04, 0x4c, 0x01},
-		}
-		unpacker = newPacketUnpacker(aead, protocol.VersionWhatever).(*packetUnpacker)
+		unpacker = newPacketUnpacker(aead, version).(*packetUnpacker)
 	})
 
 	It("errors if the packet doesn't contain any payload", func() {
-		data := []byte("foobar")
-		aead.EXPECT().Open1RTT(gomock.Any(), []byte("foobar"), hdr.PacketNumber, hdr.Raw).Return([]byte{}, nil)
+		extHdr := &wire.ExtendedHeader{
+			Header:          wire.Header{DestConnectionID: connID},
+			PacketNumber:    42,
+			PacketNumberLen: protocol.PacketNumberLen2,
+		}
+		hdr, hdrRaw := getHeader(extHdr)
+		data := append(hdrRaw, []byte("foobar")...) // add some payload
+		// return an empty (unencrypted) payload
+		aead.EXPECT().Open1RTT(gomock.Any(), []byte("foobar"), extHdr.PacketNumber, hdrRaw).Return([]byte{}, nil)
 		_, err := unpacker.Unpack(hdr, data)
 		Expect(err).To(MatchError(qerr.MissingPayload))
 	})
 
 	It("opens Initial packets", func() {
-		hdr.IsLongHeader = true
-		hdr.Type = protocol.PacketTypeInitial
-		aead.EXPECT().OpenInitial(gomock.Any(), gomock.Any(), hdr.PacketNumber, hdr.Raw).Return([]byte{0}, nil)
-		packet, err := unpacker.Unpack(hdr, nil)
+		extHdr := &wire.ExtendedHeader{
+			Header: wire.Header{
+				IsLongHeader:     true,
+				Type:             protocol.PacketTypeInitial,
+				Length:           3 + 6, // packet number len + payload
+				DestConnectionID: connID,
+				Version:          version,
+			},
+			PacketNumber:    2,
+			PacketNumberLen: 3,
+		}
+		hdr, hdrRaw := getHeader(extHdr)
+		aead.EXPECT().OpenInitial(gomock.Any(), []byte("foobar"), extHdr.PacketNumber, hdrRaw).Return([]byte{0}, nil)
+		packet, err := unpacker.Unpack(hdr, append(hdrRaw, []byte("foobar")...))
 		Expect(err).ToNot(HaveOccurred())
 		Expect(packet.encryptionLevel).To(Equal(protocol.EncryptionInitial))
 	})
 
 	It("opens Handshake packets", func() {
-		hdr.IsLongHeader = true
-		hdr.Type = protocol.PacketTypeHandshake
-		aead.EXPECT().OpenHandshake(gomock.Any(), gomock.Any(), hdr.PacketNumber, hdr.Raw).Return([]byte{0}, nil)
-		packet, err := unpacker.Unpack(hdr, nil)
+		extHdr := &wire.ExtendedHeader{
+			Header: wire.Header{
+				IsLongHeader:     true,
+				Type:             protocol.PacketTypeHandshake,
+				Length:           3 + 6, // packet number len + payload
+				DestConnectionID: connID,
+				Version:          version,
+			},
+			PacketNumber:    2,
+			PacketNumberLen: 3,
+		}
+		hdr, hdrRaw := getHeader(extHdr)
+		aead.EXPECT().OpenHandshake(gomock.Any(), gomock.Any(), extHdr.PacketNumber, hdrRaw).Return([]byte{0}, nil)
+		packet, err := unpacker.Unpack(hdr, append(hdrRaw, []byte("foobar")...))
 		Expect(err).ToNot(HaveOccurred())
 		Expect(packet.encryptionLevel).To(Equal(protocol.EncryptionHandshake))
 	})
 
+	It("errors on packets that are smaller than the length in the packet header", func() {
+		extHdr := &wire.ExtendedHeader{
+			Header: wire.Header{
+				IsLongHeader:     true,
+				Type:             protocol.PacketTypeHandshake,
+				Length:           1000,
+				DestConnectionID: connID,
+				Version:          version,
+			},
+			PacketNumberLen: protocol.PacketNumberLen2,
+		}
+		hdr, hdrRaw := getHeader(extHdr)
+		data := append(hdrRaw, make([]byte, 500-2 /* for packet number length */)...)
+		_, err := unpacker.Unpack(hdr, data)
+		Expect(err).To(MatchError("packet length (500 bytes) is smaller than the expected length (1000 bytes)"))
+	})
+
+	It("errors when receiving a packet that has a length smaller than the packet number length", func() {
+		extHdr := &wire.ExtendedHeader{
+			Header: wire.Header{
+				IsLongHeader:     true,
+				DestConnectionID: connID,
+				Type:             protocol.PacketTypeHandshake,
+				Length:           3,
+				Version:          protocol.VersionTLS,
+			},
+			PacketNumberLen: protocol.PacketNumberLen4,
+		}
+		hdr, hdrRaw := getHeader(extHdr)
+		_, err := unpacker.Unpack(hdr, hdrRaw)
+		Expect(err).To(MatchError("packet length (3 bytes) shorter than packet number (4 bytes)"))
+	})
+
+	It("cuts packets to the right length", func() {
+		pnLen := protocol.PacketNumberLen2
+		extHdr := &wire.ExtendedHeader{
+			Header: wire.Header{
+				IsLongHeader:     true,
+				DestConnectionID: connID,
+				Type:             protocol.PacketTypeHandshake,
+				Length:           456,
+				Version:          protocol.VersionTLS,
+			},
+			PacketNumberLen: pnLen,
+		}
+		payloadLen := 456 - int(pnLen)
+		hdr, hdrRaw := getHeader(extHdr)
+		data := append(hdrRaw, make([]byte, payloadLen)...)
+		aead.EXPECT().OpenHandshake(gomock.Any(), gomock.Any(), extHdr.PacketNumber, hdrRaw).DoAndReturn(func(_, payload []byte, _ protocol.PacketNumber, _ []byte) ([]byte, error) {
+			Expect(payload).To(HaveLen(payloadLen))
+			return []byte{0}, nil
+		})
+		_, err := unpacker.Unpack(hdr, data)
+		Expect(err).ToNot(HaveOccurred())
+	})
+
 	It("returns the error when unpacking fails", func() {
-		hdr.IsLongHeader = true
-		hdr.Type = protocol.PacketTypeHandshake
+		extHdr := &wire.ExtendedHeader{
+			Header: wire.Header{
+				IsLongHeader:     true,
+				Type:             protocol.PacketTypeHandshake,
+				Length:           3, // packet number len
+				DestConnectionID: connID,
+				Version:          version,
+			},
+			PacketNumber:    2,
+			PacketNumberLen: 3,
+		}
+		hdr, hdrRaw := getHeader(extHdr)
 		aead.EXPECT().OpenHandshake(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, errors.New("test err"))
-		_, err := unpacker.Unpack(hdr, nil)
+		_, err := unpacker.Unpack(hdr, hdrRaw)
 		Expect(err).To(MatchError(qerr.Error(qerr.DecryptionFailure, "test err")))
 	})
 
 	It("decodes the packet number", func() {
 		firstHdr := &wire.ExtendedHeader{
+			Header:          wire.Header{DestConnectionID: connID},
 			PacketNumber:    0x1337,
 			PacketNumberLen: 2,
 		}
 		aead.EXPECT().Open1RTT(gomock.Any(), gomock.Any(), firstHdr.PacketNumber, gomock.Any()).Return([]byte{0}, nil)
-		packet, err := unpacker.Unpack(firstHdr, nil)
+		packet, err := unpacker.Unpack(getHeader(firstHdr))
 		Expect(err).ToNot(HaveOccurred())
 		Expect(packet.packetNumber).To(Equal(protocol.PacketNumber(0x1337)))
 		// the real packet number is 0x1338, but only the last byte is sent
 		secondHdr := &wire.ExtendedHeader{
+			Header:          wire.Header{DestConnectionID: connID},
 			PacketNumber:    0x38,
 			PacketNumberLen: 1,
 		}
 		// expect the call with the decoded packet number
 		aead.EXPECT().Open1RTT(gomock.Any(), gomock.Any(), protocol.PacketNumber(0x1338), gomock.Any()).Return([]byte{0}, nil)
-		packet, err = unpacker.Unpack(secondHdr, nil)
+		packet, err = unpacker.Unpack(getHeader(secondHdr))
 		Expect(err).ToNot(HaveOccurred())
 		Expect(packet.packetNumber).To(Equal(protocol.PacketNumber(0x1338)))
 	})
 
 	It("unpacks the frames", func() {
+		extHdr := &wire.ExtendedHeader{
+			Header:          wire.Header{DestConnectionID: connID},
+			PacketNumber:    0x1337,
+			PacketNumberLen: 2,
+		}
 		buf := &bytes.Buffer{}
 		(&wire.PingFrame{}).Write(buf, protocol.VersionWhatever)
 		(&wire.DataBlockedFrame{}).Write(buf, protocol.VersionWhatever)
-		aead.EXPECT().Open1RTT(gomock.Any(), gomock.Any(), hdr.PacketNumber, hdr.Raw).Return(buf.Bytes(), nil)
-		packet, err := unpacker.Unpack(hdr, nil)
+		hdr, hdrRaw := getHeader(extHdr)
+		aead.EXPECT().Open1RTT(gomock.Any(), gomock.Any(), extHdr.PacketNumber, hdrRaw).Return(buf.Bytes(), nil)
+		packet, err := unpacker.Unpack(hdr, append(hdrRaw, buf.Bytes()...))
 		Expect(err).ToNot(HaveOccurred())
 		Expect(packet.frames).To(Equal([]wire.Frame{&wire.PingFrame{}, &wire.DataBlockedFrame{}}))
 	})

+ 7 - 28
session.go

@@ -1,7 +1,6 @@
 package quic
 
 import (
-	"bytes"
 	"context"
 	"crypto/tls"
 	"errors"
@@ -22,7 +21,7 @@ import (
 )
 
 type unpacker interface {
-	Unpack(hdr *wire.ExtendedHeader, data []byte) (*unpackedPacket, error)
+	Unpack(hdr *wire.Header, data []byte) (*unpackedPacket, error)
 }
 
 type streamGetter interface {
@@ -483,27 +482,7 @@ func (s *session) handlePacketImpl(p *receivedPacket) error {
 		return nil
 	}
 
-	data := p.data
-	r := bytes.NewReader(data)
-	hdr, err := p.hdr.ParseExtended(r, s.version)
-	if err != nil {
-		return fmt.Errorf("error parsing extended header: %s", err)
-	}
-	hdr.Raw = data[:len(data)-r.Len()]
-	data = data[len(data)-r.Len():]
-
-	if hdr.IsLongHeader {
-		if hdr.Length < protocol.ByteCount(hdr.PacketNumberLen) {
-			return fmt.Errorf("packet length (%d bytes) shorter than packet number (%d bytes)", hdr.Length, hdr.PacketNumberLen)
-		}
-		if protocol.ByteCount(len(data))+protocol.ByteCount(hdr.PacketNumberLen) < hdr.Length {
-			return fmt.Errorf("packet length (%d bytes) is smaller than the expected length (%d bytes)", len(data)+int(hdr.PacketNumberLen), hdr.Length)
-		}
-		data = data[:int(hdr.Length)-int(hdr.PacketNumberLen)]
-		// TODO(#1312): implement parsing of compound packets
-	}
-
-	packet, err := s.unpacker.Unpack(hdr, data)
+	packet, err := s.unpacker.Unpack(p.hdr, p.data)
 	// if the decryption failed, this might be a packet sent by an attacker
 	if err != nil {
 		return err
@@ -511,13 +490,13 @@ func (s *session) handlePacketImpl(p *receivedPacket) error {
 
 	if s.logger.Debug() {
 		s.logger.Debugf("<- Reading packet %#x (%d bytes) for connection %s, %s", packet.packetNumber, len(p.data), p.hdr.DestConnectionID, packet.encryptionLevel)
-		hdr.Log(s.logger)
+		packet.hdr.Log(s.logger)
 	}
 
 	// The server can change the source connection ID with the first Handshake packet.
-	if s.perspective == protocol.PerspectiveClient && !s.receivedFirstPacket && hdr.IsLongHeader && !hdr.SrcConnectionID.Equal(s.destConnID) {
-		s.logger.Debugf("Received first packet. Switching destination connection ID to: %s", hdr.SrcConnectionID)
-		s.destConnID = hdr.SrcConnectionID
+	if s.perspective == protocol.PerspectiveClient && !s.receivedFirstPacket && p.hdr.IsLongHeader && !p.hdr.SrcConnectionID.Equal(s.destConnID) {
+		s.logger.Debugf("Received first packet. Switching destination connection ID to: %s", p.hdr.SrcConnectionID)
+		s.destConnID = p.hdr.SrcConnectionID
 		s.packer.ChangeDestConnectionID(s.destConnID)
 	}
 
@@ -536,7 +515,7 @@ func (s *session) handlePacketImpl(p *receivedPacket) error {
 
 	// If this is a Retry packet, there's no need to send an ACK.
 	// The session will be closed and recreated as soon as the crypto setup processed the HRR.
-	if hdr.Type != protocol.PacketTypeRetry {
+	if p.hdr.Type != protocol.PacketTypeRetry {
 		isRetransmittable := ackhandler.HasRetransmittableFrames(packet.frames)
 		if err := s.receivedPacketHandler.ReceivedPacket(packet.packetNumber, p.rcvTime, isRetransmittable); err != nil {
 			return err

+ 0 - 56
session_test.go

@@ -533,62 +533,6 @@ var _ = Describe("Session", func() {
 			})).To(Succeed())
 		})
 
-		It("errors on packets that are smaller than the length in the packet header", func() {
-			connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
-			hdr := &wire.ExtendedHeader{
-				Header: wire.Header{
-					IsLongHeader:     true,
-					Type:             protocol.PacketTypeHandshake,
-					Length:           1000,
-					DestConnectionID: connID,
-					Version:          protocol.VersionTLS,
-				},
-				PacketNumberLen: protocol.PacketNumberLen2,
-			}
-			data := getData(hdr)
-			data = append(data, make([]byte, 500-2 /* for packet number length */)...)
-			Expect(sess.handlePacketImpl(&receivedPacket{hdr: &hdr.Header, data: data})).To(MatchError("packet length (500 bytes) is smaller than the expected length (1000 bytes)"))
-		})
-
-		It("errors when receiving a packet that has a length smaller than the packet number length", func() {
-			connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
-			hdr := &wire.ExtendedHeader{
-				Header: wire.Header{
-					IsLongHeader:     true,
-					DestConnectionID: connID,
-					Type:             protocol.PacketTypeHandshake,
-					Length:           3,
-					Version:          protocol.VersionTLS,
-				},
-				PacketNumberLen: protocol.PacketNumberLen4,
-			}
-			data := getData(hdr)
-			Expect(sess.handlePacketImpl(&receivedPacket{hdr: &hdr.Header, data: data})).To(MatchError("packet length (3 bytes) shorter than packet number (4 bytes)"))
-		})
-
-		It("cuts packets to the right length", func() {
-			connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
-			pnLen := protocol.PacketNumberLen2
-			hdr := &wire.ExtendedHeader{
-				Header: wire.Header{
-					IsLongHeader:     true,
-					DestConnectionID: connID,
-					Type:             protocol.PacketTypeHandshake,
-					Length:           456,
-					Version:          protocol.VersionTLS,
-				},
-				PacketNumberLen: pnLen,
-			}
-			payloadLen := 456 - int(pnLen)
-			data := getData(hdr)
-			data = append(data, make([]byte, payloadLen)...)
-			unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any()).DoAndReturn(func(_ *wire.ExtendedHeader, data []byte) (*unpackedPacket, error) {
-				Expect(data).To(HaveLen(payloadLen))
-				return &unpackedPacket{}, nil
-			})
-			Expect(sess.handlePacketImpl(&receivedPacket{hdr: &hdr.Header, data: data})).To(Succeed())
-		})
-
 		Context("updating the remote address", func() {
 			It("doesn't support connection migration", func() {
 				unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any()).Return(&unpackedPacket{}, nil)