Browse Source

return the Opener from the crypto setup

Marten Seemann 8 months ago
parent
commit
67f923c736

+ 16 - 15
internal/handshake/crypto_setup.go

@@ -493,22 +493,23 @@ func (h *cryptoSetup) GetSealerWithEncryptionLevel(level protocol.EncryptionLeve
 	}
 }
 
-func (h *cryptoSetup) OpenInitial(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error) {
-	return h.initialOpener.Open(dst, src, pn, ad)
-}
-
-func (h *cryptoSetup) OpenHandshake(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error) {
-	if h.handshakeOpener == nil {
-		return nil, errors.New("no handshake opener")
-	}
-	return h.handshakeOpener.Open(dst, src, pn, ad)
-}
-
-func (h *cryptoSetup) Open1RTT(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error) {
-	if h.opener == nil {
-		return nil, errors.New("no 1-RTT opener")
+func (h *cryptoSetup) GetOpener(level protocol.EncryptionLevel) (Opener, error) {
+	switch level {
+	case protocol.EncryptionInitial:
+		return h.initialOpener, nil
+	case protocol.EncryptionHandshake:
+		if h.handshakeOpener == nil {
+			return nil, errors.New("CryptoSetup: no opener with encryption level Handshake")
+		}
+		return h.handshakeOpener, nil
+	case protocol.Encryption1RTT:
+		if h.opener == nil {
+			return nil, errors.New("CryptoSetup: no opener with encryption level 1-RTT")
+		}
+		return h.opener, nil
+	default:
+		return nil, fmt.Errorf("CryptoSetup: no opener with encryption level %s", level)
 	}
-	return h.opener.Open(dst, src, pn, ad)
 }
 
 func (h *cryptoSetup) ConnectionState() ConnectionState {

+ 1 - 4
internal/handshake/interface.go

@@ -35,10 +35,7 @@ type CryptoSetup interface {
 
 	GetSealer() (protocol.EncryptionLevel, Sealer)
 	GetSealerWithEncryptionLevel(protocol.EncryptionLevel) (Sealer, error)
-
-	OpenInitial(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error)
-	OpenHandshake(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error)
-	Open1RTT(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error)
+	GetOpener(protocol.EncryptionLevel) (Opener, error)
 }
 
 // ConnectionState records basic details about the QUIC connection.

+ 13 - 39
internal/mocks/crypto_setup.go

@@ -59,6 +59,19 @@ func (mr *MockCryptoSetupMockRecorder) ConnectionState() *gomock.Call {
 	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ConnectionState", reflect.TypeOf((*MockCryptoSetup)(nil).ConnectionState))
 }
 
+// GetOpener mocks base method
+func (m *MockCryptoSetup) GetOpener(arg0 protocol.EncryptionLevel) (handshake.Opener, error) {
+	ret := m.ctrl.Call(m, "GetOpener", arg0)
+	ret0, _ := ret[0].(handshake.Opener)
+	ret1, _ := ret[1].(error)
+	return ret0, ret1
+}
+
+// GetOpener indicates an expected call of GetOpener
+func (mr *MockCryptoSetupMockRecorder) GetOpener(arg0 interface{}) *gomock.Call {
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOpener", reflect.TypeOf((*MockCryptoSetup)(nil).GetOpener), arg0)
+}
+
 // GetSealer mocks base method
 func (m *MockCryptoSetup) GetSealer() (protocol.EncryptionLevel, handshake.Sealer) {
 	ret := m.ctrl.Call(m, "GetSealer")
@@ -97,45 +110,6 @@ func (mr *MockCryptoSetupMockRecorder) HandleMessage(arg0, arg1 interface{}) *go
 	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandleMessage", reflect.TypeOf((*MockCryptoSetup)(nil).HandleMessage), arg0, arg1)
 }
 
-// Open1RTT mocks base method
-func (m *MockCryptoSetup) Open1RTT(arg0, arg1 []byte, arg2 protocol.PacketNumber, arg3 []byte) ([]byte, error) {
-	ret := m.ctrl.Call(m, "Open1RTT", arg0, arg1, arg2, arg3)
-	ret0, _ := ret[0].([]byte)
-	ret1, _ := ret[1].(error)
-	return ret0, ret1
-}
-
-// Open1RTT indicates an expected call of Open1RTT
-func (mr *MockCryptoSetupMockRecorder) Open1RTT(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
-	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Open1RTT", reflect.TypeOf((*MockCryptoSetup)(nil).Open1RTT), arg0, arg1, arg2, arg3)
-}
-
-// OpenHandshake mocks base method
-func (m *MockCryptoSetup) OpenHandshake(arg0, arg1 []byte, arg2 protocol.PacketNumber, arg3 []byte) ([]byte, error) {
-	ret := m.ctrl.Call(m, "OpenHandshake", arg0, arg1, arg2, arg3)
-	ret0, _ := ret[0].([]byte)
-	ret1, _ := ret[1].(error)
-	return ret0, ret1
-}
-
-// OpenHandshake indicates an expected call of OpenHandshake
-func (mr *MockCryptoSetupMockRecorder) OpenHandshake(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
-	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenHandshake", reflect.TypeOf((*MockCryptoSetup)(nil).OpenHandshake), arg0, arg1, arg2, arg3)
-}
-
-// OpenInitial mocks base method
-func (m *MockCryptoSetup) OpenInitial(arg0, arg1 []byte, arg2 protocol.PacketNumber, arg3 []byte) ([]byte, error) {
-	ret := m.ctrl.Call(m, "OpenInitial", arg0, arg1, arg2, arg3)
-	ret0, _ := ret[0].([]byte)
-	ret1, _ := ret[1].(error)
-	return ret0, ret1
-}
-
-// OpenInitial indicates an expected call of OpenInitial
-func (mr *MockCryptoSetupMockRecorder) OpenInitial(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
-	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenInitial", reflect.TypeOf((*MockCryptoSetup)(nil).OpenInitial), arg0, arg1, arg2, arg3)
-}
-
 // RunHandshake mocks base method
 func (m *MockCryptoSetup) RunHandshake() error {
 	ret := m.ctrl.Call(m, "RunHandshake")

+ 1 - 0
internal/mocks/mockgen.go

@@ -1,6 +1,7 @@
 package mocks
 
 //go:generate sh -c "../mockgen_internal.sh mocks sealer.go github.com/lucas-clemente/quic-go/internal/handshake Sealer"
+//go:generate sh -c "../mockgen_internal.sh mocks opener.go github.com/lucas-clemente/quic-go/internal/handshake Opener"
 //go:generate sh -c "../mockgen_internal.sh mocks crypto_setup.go github.com/lucas-clemente/quic-go/internal/handshake CryptoSetup"
 //go:generate sh -c "../mockgen_internal.sh mocks stream_flow_controller.go github.com/lucas-clemente/quic-go/internal/flowcontrol StreamFlowController"
 //go:generate sh -c "../mockgen_internal.sh mockackhandler ackhandler/sent_packet_handler.go github.com/lucas-clemente/quic-go/internal/ackhandler SentPacketHandler"

+ 48 - 0
internal/mocks/opener.go

@@ -0,0 +1,48 @@
+// Code generated by MockGen. DO NOT EDIT.
+// Source: github.com/lucas-clemente/quic-go/internal/handshake (interfaces: Opener)
+
+// Package mocks is a generated GoMock package.
+package mocks
+
+import (
+	reflect "reflect"
+
+	gomock "github.com/golang/mock/gomock"
+	protocol "github.com/lucas-clemente/quic-go/internal/protocol"
+)
+
+// MockOpener is a mock of Opener interface
+type MockOpener struct {
+	ctrl     *gomock.Controller
+	recorder *MockOpenerMockRecorder
+}
+
+// MockOpenerMockRecorder is the mock recorder for MockOpener
+type MockOpenerMockRecorder struct {
+	mock *MockOpener
+}
+
+// NewMockOpener creates a new mock instance
+func NewMockOpener(ctrl *gomock.Controller) *MockOpener {
+	mock := &MockOpener{ctrl: ctrl}
+	mock.recorder = &MockOpenerMockRecorder{mock}
+	return mock
+}
+
+// EXPECT returns an object that allows the caller to indicate expected use
+func (m *MockOpener) EXPECT() *MockOpenerMockRecorder {
+	return m.recorder
+}
+
+// Open mocks base method
+func (m *MockOpener) Open(arg0, arg1 []byte, arg2 protocol.PacketNumber, arg3 []byte) ([]byte, error) {
+	ret := m.ctrl.Call(m, "Open", arg0, arg1, arg2, arg3)
+	ret0, _ := ret[0].([]byte)
+	ret1, _ := ret[1].(error)
+	return ret0, ret1
+}
+
+// Open indicates an expected call of Open
+func (mr *MockOpenerMockRecorder) Open(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Open", reflect.TypeOf((*MockOpener)(nil).Open), arg0, arg1, arg2, arg3)
+}

+ 0 - 74
mock_quic_aead_test.go

@@ -1,74 +0,0 @@
-// Code generated by MockGen. DO NOT EDIT.
-// Source: github.com/lucas-clemente/quic-go (interfaces: QuicAEAD)
-
-// Package quic is a generated GoMock package.
-package quic
-
-import (
-	reflect "reflect"
-
-	gomock "github.com/golang/mock/gomock"
-	protocol "github.com/lucas-clemente/quic-go/internal/protocol"
-)
-
-// MockQuicAEAD is a mock of QuicAEAD interface
-type MockQuicAEAD struct {
-	ctrl     *gomock.Controller
-	recorder *MockQuicAEADMockRecorder
-}
-
-// MockQuicAEADMockRecorder is the mock recorder for MockQuicAEAD
-type MockQuicAEADMockRecorder struct {
-	mock *MockQuicAEAD
-}
-
-// NewMockQuicAEAD creates a new mock instance
-func NewMockQuicAEAD(ctrl *gomock.Controller) *MockQuicAEAD {
-	mock := &MockQuicAEAD{ctrl: ctrl}
-	mock.recorder = &MockQuicAEADMockRecorder{mock}
-	return mock
-}
-
-// EXPECT returns an object that allows the caller to indicate expected use
-func (m *MockQuicAEAD) EXPECT() *MockQuicAEADMockRecorder {
-	return m.recorder
-}
-
-// Open1RTT mocks base method
-func (m *MockQuicAEAD) Open1RTT(arg0, arg1 []byte, arg2 protocol.PacketNumber, arg3 []byte) ([]byte, error) {
-	ret := m.ctrl.Call(m, "Open1RTT", arg0, arg1, arg2, arg3)
-	ret0, _ := ret[0].([]byte)
-	ret1, _ := ret[1].(error)
-	return ret0, ret1
-}
-
-// Open1RTT indicates an expected call of Open1RTT
-func (mr *MockQuicAEADMockRecorder) Open1RTT(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
-	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Open1RTT", reflect.TypeOf((*MockQuicAEAD)(nil).Open1RTT), arg0, arg1, arg2, arg3)
-}
-
-// OpenHandshake mocks base method
-func (m *MockQuicAEAD) OpenHandshake(arg0, arg1 []byte, arg2 protocol.PacketNumber, arg3 []byte) ([]byte, error) {
-	ret := m.ctrl.Call(m, "OpenHandshake", arg0, arg1, arg2, arg3)
-	ret0, _ := ret[0].([]byte)
-	ret1, _ := ret[1].(error)
-	return ret0, ret1
-}
-
-// OpenHandshake indicates an expected call of OpenHandshake
-func (mr *MockQuicAEADMockRecorder) OpenHandshake(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
-	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenHandshake", reflect.TypeOf((*MockQuicAEAD)(nil).OpenHandshake), arg0, arg1, arg2, arg3)
-}
-
-// OpenInitial mocks base method
-func (m *MockQuicAEAD) OpenInitial(arg0, arg1 []byte, arg2 protocol.PacketNumber, arg3 []byte) ([]byte, error) {
-	ret := m.ctrl.Call(m, "OpenInitial", arg0, arg1, arg2, arg3)
-	ret0, _ := ret[0].([]byte)
-	ret1, _ := ret[1].(error)
-	return ret0, ret1
-}
-
-// OpenInitial indicates an expected call of OpenInitial
-func (mr *MockQuicAEADMockRecorder) OpenInitial(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
-	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenInitial", reflect.TypeOf((*MockQuicAEAD)(nil).OpenInitial), arg0, arg1, arg2, arg3)
-}

+ 0 - 1
mockgen.go

@@ -13,7 +13,6 @@ package quic
 //go:generate sh -c "./mockgen_private.sh quic mock_sealing_manager_test.go github.com/lucas-clemente/quic-go sealingManager"
 //go:generate sh -c "./mockgen_private.sh quic mock_unpacker_test.go github.com/lucas-clemente/quic-go unpacker"
 //go:generate sh -c "./mockgen_private.sh quic mock_packer_test.go github.com/lucas-clemente/quic-go packer"
-//go:generate sh -c "./mockgen_private.sh quic mock_quic_aead_test.go github.com/lucas-clemente/quic-go quicAEAD"
 //go:generate sh -c "./mockgen_private.sh quic mock_session_runner_test.go github.com/lucas-clemente/quic-go sessionRunner"
 //go:generate sh -c "./mockgen_private.sh quic mock_quic_session_test.go github.com/lucas-clemente/quic-go quicSession"
 //go:generate sh -c "./mockgen_private.sh quic mock_packet_handler_test.go github.com/lucas-clemente/quic-go packetHandler"

+ 14 - 18
packet_unpacker.go

@@ -4,6 +4,7 @@ import (
 	"bytes"
 	"fmt"
 
+	"github.com/lucas-clemente/quic-go/internal/handshake"
 	"github.com/lucas-clemente/quic-go/internal/protocol"
 	"github.com/lucas-clemente/quic-go/internal/qerr"
 	"github.com/lucas-clemente/quic-go/internal/utils"
@@ -17,15 +18,9 @@ type unpackedPacket struct {
 	frames          []wire.Frame
 }
 
-type quicAEAD interface {
-	OpenInitial(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error)
-	OpenHandshake(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error)
-	Open1RTT(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error)
-}
-
 // The packetUnpacker unpacks QUIC packets.
 type packetUnpacker struct {
-	aead quicAEAD
+	cs handshake.CryptoSetup
 
 	largestRcvdPacketNumber protocol.PacketNumber
 
@@ -34,9 +29,9 @@ type packetUnpacker struct {
 
 var _ unpacker = &packetUnpacker{}
 
-func newPacketUnpacker(aead quicAEAD, version protocol.VersionNumber) unpacker {
+func newPacketUnpacker(cs handshake.CryptoSetup, version protocol.VersionNumber) unpacker {
 	return &packetUnpacker{
-		aead:    aead,
+		cs:      cs,
 		version: version,
 	}
 }
@@ -69,22 +64,23 @@ func (u *packetUnpacker) Unpack(hdr *wire.Header, data []byte) (*unpackedPacket,
 	buf = buf[:0]
 	defer putPacketBuffer(&buf)
 
-	var decrypted []byte
-	var encryptionLevel protocol.EncryptionLevel
+	var encLevel protocol.EncryptionLevel
 	switch hdr.Type {
 	case protocol.PacketTypeInitial:
-		decrypted, err = u.aead.OpenInitial(buf, data, pn, extHdr.Raw)
-		encryptionLevel = protocol.EncryptionInitial
+		encLevel = protocol.EncryptionInitial
 	case protocol.PacketTypeHandshake:
-		decrypted, err = u.aead.OpenHandshake(buf, data, pn, extHdr.Raw)
-		encryptionLevel = protocol.EncryptionHandshake
+		encLevel = protocol.EncryptionHandshake
 	default:
 		if hdr.IsLongHeader {
 			return nil, fmt.Errorf("unknown packet type: %s", hdr.Type)
 		}
-		decrypted, err = u.aead.Open1RTT(buf, data, pn, extHdr.Raw)
-		encryptionLevel = protocol.Encryption1RTT
+		encLevel = protocol.Encryption1RTT
+	}
+	opener, err := u.cs.GetOpener(encLevel)
+	if err != nil {
+		return nil, qerr.Error(qerr.DecryptionFailure, err.Error())
 	}
+	decrypted, err := opener.Open(buf, data, pn, extHdr.Raw)
 	if err != nil {
 		return nil, qerr.Error(qerr.DecryptionFailure, err.Error())
 	}
@@ -100,7 +96,7 @@ func (u *packetUnpacker) Unpack(hdr *wire.Header, data []byte) (*unpackedPacket,
 	return &unpackedPacket{
 		hdr:             extHdr,
 		packetNumber:    pn,
-		encryptionLevel: encryptionLevel,
+		encryptionLevel: encLevel,
 		frames:          fs,
 	}, nil
 }

+ 35 - 29
packet_unpacker_test.go

@@ -5,6 +5,7 @@ import (
 	"errors"
 
 	"github.com/golang/mock/gomock"
+	"github.com/lucas-clemente/quic-go/internal/mocks"
 	"github.com/lucas-clemente/quic-go/internal/protocol"
 	"github.com/lucas-clemente/quic-go/internal/qerr"
 	"github.com/lucas-clemente/quic-go/internal/wire"
@@ -17,7 +18,7 @@ var _ = Describe("Packet Unpacker", func() {
 	const version = protocol.VersionTLS
 	var (
 		unpacker *packetUnpacker
-		aead     *MockQuicAEAD
+		cs       *mocks.MockCryptoSetup
 		connID   = protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}
 	)
 
@@ -30,8 +31,8 @@ var _ = Describe("Packet Unpacker", func() {
 	}
 
 	BeforeEach(func() {
-		aead = NewMockQuicAEAD(mockCtrl)
-		unpacker = newPacketUnpacker(aead, version).(*packetUnpacker)
+		cs = mocks.NewMockCryptoSetup(mockCtrl)
+		unpacker = newPacketUnpacker(cs, version).(*packetUnpacker)
 	})
 
 	It("errors if the packet doesn't contain any payload", func() {
@@ -43,7 +44,9 @@ var _ = Describe("Packet Unpacker", func() {
 		hdr, hdrRaw := getHeader(extHdr)
 		data := append(hdrRaw, []byte("foobar")...) // add some payload
 		// return an empty (unencrypted) payload
-		aead.EXPECT().Open1RTT(gomock.Any(), []byte("foobar"), extHdr.PacketNumber, hdrRaw).Return([]byte{}, nil)
+		opener := mocks.NewMockOpener(mockCtrl)
+		cs.EXPECT().GetOpener(protocol.Encryption1RTT).Return(opener, nil)
+		opener.EXPECT().Open(gomock.Any(), []byte("foobar"), extHdr.PacketNumber, hdrRaw).Return([]byte{}, nil)
 		_, err := unpacker.Unpack(hdr, data)
 		Expect(err).To(MatchError(qerr.MissingPayload))
 	})
@@ -61,31 +64,14 @@ var _ = Describe("Packet Unpacker", func() {
 			PacketNumberLen: 3,
 		}
 		hdr, hdrRaw := getHeader(extHdr)
-		aead.EXPECT().OpenInitial(gomock.Any(), []byte("foobar"), extHdr.PacketNumber, hdrRaw).Return([]byte{0}, nil)
+		opener := mocks.NewMockOpener(mockCtrl)
+		cs.EXPECT().GetOpener(protocol.EncryptionInitial).Return(opener, nil)
+		opener.EXPECT().Open(gomock.Any(), []byte("foobar"), extHdr.PacketNumber, hdrRaw).Return([]byte{0}, nil)
 		packet, err := unpacker.Unpack(hdr, append(hdrRaw, []byte("foobar")...))
 		Expect(err).ToNot(HaveOccurred())
 		Expect(packet.encryptionLevel).To(Equal(protocol.EncryptionInitial))
 	})
 
-	It("opens Handshake packets", func() {
-		extHdr := &wire.ExtendedHeader{
-			Header: wire.Header{
-				IsLongHeader:     true,
-				Type:             protocol.PacketTypeHandshake,
-				Length:           3 + 6, // packet number len + payload
-				DestConnectionID: connID,
-				Version:          version,
-			},
-			PacketNumber:    2,
-			PacketNumberLen: 3,
-		}
-		hdr, hdrRaw := getHeader(extHdr)
-		aead.EXPECT().OpenHandshake(gomock.Any(), gomock.Any(), extHdr.PacketNumber, hdrRaw).Return([]byte{0}, nil)
-		packet, err := unpacker.Unpack(hdr, append(hdrRaw, []byte("foobar")...))
-		Expect(err).ToNot(HaveOccurred())
-		Expect(packet.encryptionLevel).To(Equal(protocol.EncryptionHandshake))
-	})
-
 	It("errors on packets that are smaller than the length in the packet header", func() {
 		extHdr := &wire.ExtendedHeader{
 			Header: wire.Header{
@@ -118,7 +104,9 @@ var _ = Describe("Packet Unpacker", func() {
 		payloadLen := 456 - int(pnLen)
 		hdr, hdrRaw := getHeader(extHdr)
 		data := append(hdrRaw, make([]byte, payloadLen)...)
-		aead.EXPECT().OpenHandshake(gomock.Any(), gomock.Any(), extHdr.PacketNumber, hdrRaw).DoAndReturn(func(_, payload []byte, _ protocol.PacketNumber, _ []byte) ([]byte, error) {
+		opener := mocks.NewMockOpener(mockCtrl)
+		cs.EXPECT().GetOpener(protocol.EncryptionHandshake).Return(opener, nil)
+		opener.EXPECT().Open(gomock.Any(), gomock.Any(), extHdr.PacketNumber, hdrRaw).DoAndReturn(func(_, payload []byte, _ protocol.PacketNumber, _ []byte) ([]byte, error) {
 			Expect(payload).To(HaveLen(payloadLen))
 			return []byte{0}, nil
 		})
@@ -126,6 +114,18 @@ var _ = Describe("Packet Unpacker", func() {
 		Expect(err).ToNot(HaveOccurred())
 	})
 
+	It("returns the error when getting the sealer fails", func() {
+		extHdr := &wire.ExtendedHeader{
+			Header:          wire.Header{DestConnectionID: connID},
+			PacketNumber:    0x1337,
+			PacketNumberLen: 2,
+		}
+		hdr, hdrRaw := getHeader(extHdr)
+		cs.EXPECT().GetOpener(protocol.Encryption1RTT).Return(nil, errors.New("test err"))
+		_, err := unpacker.Unpack(hdr, hdrRaw)
+		Expect(err).To(MatchError(qerr.Error(qerr.DecryptionFailure, "test err")))
+	})
+
 	It("returns the error when unpacking fails", func() {
 		extHdr := &wire.ExtendedHeader{
 			Header: wire.Header{
@@ -139,7 +139,9 @@ var _ = Describe("Packet Unpacker", func() {
 			PacketNumberLen: 3,
 		}
 		hdr, hdrRaw := getHeader(extHdr)
-		aead.EXPECT().OpenHandshake(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, errors.New("test err"))
+		opener := mocks.NewMockOpener(mockCtrl)
+		cs.EXPECT().GetOpener(protocol.EncryptionHandshake).Return(opener, nil)
+		opener.EXPECT().Open(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, errors.New("test err"))
 		_, err := unpacker.Unpack(hdr, hdrRaw)
 		Expect(err).To(MatchError(qerr.Error(qerr.DecryptionFailure, "test err")))
 	})
@@ -150,7 +152,9 @@ var _ = Describe("Packet Unpacker", func() {
 			PacketNumber:    0x1337,
 			PacketNumberLen: 2,
 		}
-		aead.EXPECT().Open1RTT(gomock.Any(), gomock.Any(), firstHdr.PacketNumber, gomock.Any()).Return([]byte{0}, nil)
+		opener := mocks.NewMockOpener(mockCtrl)
+		cs.EXPECT().GetOpener(protocol.Encryption1RTT).Return(opener, nil).Times(2)
+		opener.EXPECT().Open(gomock.Any(), gomock.Any(), firstHdr.PacketNumber, gomock.Any()).Return([]byte{0}, nil)
 		packet, err := unpacker.Unpack(getHeader(firstHdr))
 		Expect(err).ToNot(HaveOccurred())
 		Expect(packet.packetNumber).To(Equal(protocol.PacketNumber(0x1337)))
@@ -161,7 +165,7 @@ var _ = Describe("Packet Unpacker", func() {
 			PacketNumberLen: 1,
 		}
 		// expect the call with the decoded packet number
-		aead.EXPECT().Open1RTT(gomock.Any(), gomock.Any(), protocol.PacketNumber(0x1338), gomock.Any()).Return([]byte{0}, nil)
+		opener.EXPECT().Open(gomock.Any(), gomock.Any(), protocol.PacketNumber(0x1338), gomock.Any()).Return([]byte{0}, nil)
 		packet, err = unpacker.Unpack(getHeader(secondHdr))
 		Expect(err).ToNot(HaveOccurred())
 		Expect(packet.packetNumber).To(Equal(protocol.PacketNumber(0x1338)))
@@ -177,7 +181,9 @@ var _ = Describe("Packet Unpacker", func() {
 		(&wire.PingFrame{}).Write(buf, protocol.VersionWhatever)
 		(&wire.DataBlockedFrame{}).Write(buf, protocol.VersionWhatever)
 		hdr, hdrRaw := getHeader(extHdr)
-		aead.EXPECT().Open1RTT(gomock.Any(), gomock.Any(), extHdr.PacketNumber, hdrRaw).Return(buf.Bytes(), nil)
+		opener := mocks.NewMockOpener(mockCtrl)
+		cs.EXPECT().GetOpener(protocol.Encryption1RTT).Return(opener, nil)
+		opener.EXPECT().Open(gomock.Any(), gomock.Any(), extHdr.PacketNumber, hdrRaw).Return(buf.Bytes(), nil)
 		packet, err := unpacker.Unpack(hdr, append(hdrRaw, buf.Bytes()...))
 		Expect(err).ToNot(HaveOccurred())
 		Expect(packet.frames).To(Equal([]wire.Frame{&wire.PingFrame{}, &wire.DataBlockedFrame{}}))