Browse Source

queue stream-level window updates from the flow controller directly

Marten Seemann 1 year ago
parent
commit
2e8a5807ba

+ 2 - 2
internal/flowcontrol/interface.go

@@ -21,8 +21,8 @@ type StreamFlowController interface {
 	// UpdateHighestReceived should be called when a new highest offset is received
 	// final has to be to true if this is the final offset of the stream, as contained in a STREAM frame with FIN bit, and the RST_STREAM frame
 	UpdateHighestReceived(offset protocol.ByteCount, final bool) error
-	// HasWindowUpdate says if it is necessary to update the window
-	HasWindowUpdate() bool
+	// MaybeQueueWindowUpdate queues a window update, if necessary
+	MaybeQueueWindowUpdate()
 }
 
 // The ConnectionFlowController is the flow controller for the connection.

+ 8 - 2
internal/flowcontrol/stream_flow_controller.go

@@ -14,6 +14,8 @@ type streamFlowController struct {
 
 	streamID protocol.StreamID
 
+	queueWindowUpdate func()
+
 	connection              connectionFlowControllerI
 	contributesToConnection bool // does the stream contribute to connection level flow control
 
@@ -30,6 +32,7 @@ func NewStreamFlowController(
 	receiveWindow protocol.ByteCount,
 	maxReceiveWindow protocol.ByteCount,
 	initialSendWindow protocol.ByteCount,
+	queueWindowUpdate func(protocol.StreamID),
 	rttStats *congestion.RTTStats,
 	logger utils.Logger,
 ) StreamFlowController {
@@ -37,6 +40,7 @@ func NewStreamFlowController(
 		streamID:                streamID,
 		contributesToConnection: contributesToConnection,
 		connection:              cfc.(connectionFlowControllerI),
+		queueWindowUpdate:       func() { queueWindowUpdate(streamID) },
 		baseFlowController: baseFlowController{
 			rttStats:             rttStats,
 			receiveWindow:        receiveWindow,
@@ -120,11 +124,13 @@ func (c *streamFlowController) IsBlocked() (bool, protocol.ByteCount) {
 	return true, c.sendWindow
 }
 
-func (c *streamFlowController) HasWindowUpdate() bool {
+func (c *streamFlowController) MaybeQueueWindowUpdate() {
 	c.mutex.Lock()
 	hasWindowUpdate := !c.receivedFinalOffset && c.hasWindowUpdate()
 	c.mutex.Unlock()
-	return hasWindowUpdate
+	if hasWindowUpdate {
+		c.queueWindowUpdate()
+	}
 }
 
 func (c *streamFlowController) GetWindowUpdate() protocol.ByteCount {

+ 34 - 11
internal/flowcontrol/stream_flow_controller_test.go

@@ -12,9 +12,13 @@ import (
 )
 
 var _ = Describe("Stream Flow controller", func() {
-	var controller *streamFlowController
+	var (
+		controller         *streamFlowController
+		queuedWindowUpdate bool
+	)
 
 	BeforeEach(func() {
+		queuedWindowUpdate = false
 		rttStats := &congestion.RTTStats{}
 		controller = &streamFlowController{
 			streamID:   10,
@@ -23,24 +27,38 @@ var _ = Describe("Stream Flow controller", func() {
 		controller.maxReceiveWindowSize = 10000
 		controller.rttStats = rttStats
 		controller.logger = utils.DefaultLogger
+		controller.queueWindowUpdate = func() { queuedWindowUpdate = true }
 	})
 
 	Context("Constructor", func() {
 		rttStats := &congestion.RTTStats{}
+		receiveWindow := protocol.ByteCount(2000)
+		maxReceiveWindow := protocol.ByteCount(3000)
+		sendWindow := protocol.ByteCount(4000)
 
 		It("sets the send and receive windows", func() {
-			receiveWindow := protocol.ByteCount(2000)
-			maxReceiveWindow := protocol.ByteCount(3000)
-			sendWindow := protocol.ByteCount(4000)
-
 			cc := NewConnectionFlowController(0, 0, nil, utils.DefaultLogger)
-			fc := NewStreamFlowController(5, true, cc, receiveWindow, maxReceiveWindow, sendWindow, rttStats, utils.DefaultLogger).(*streamFlowController)
+			fc := NewStreamFlowController(5, true, cc, receiveWindow, maxReceiveWindow, sendWindow, nil, rttStats, utils.DefaultLogger).(*streamFlowController)
 			Expect(fc.streamID).To(Equal(protocol.StreamID(5)))
 			Expect(fc.receiveWindow).To(Equal(receiveWindow))
 			Expect(fc.maxReceiveWindowSize).To(Equal(maxReceiveWindow))
 			Expect(fc.sendWindow).To(Equal(sendWindow))
 			Expect(fc.contributesToConnection).To(BeTrue())
 		})
+
+		It("queues window updates with the correction stream ID", func() {
+			var queued bool
+			queueWindowUpdate := func(id protocol.StreamID) {
+				Expect(id).To(Equal(protocol.StreamID(5)))
+				queued = true
+			}
+
+			cc := NewConnectionFlowController(0, 0, nil, utils.DefaultLogger)
+			fc := NewStreamFlowController(5, true, cc, receiveWindow, maxReceiveWindow, sendWindow, queueWindowUpdate, rttStats, utils.DefaultLogger).(*streamFlowController)
+			fc.AddBytesRead(receiveWindow)
+			fc.MaybeQueueWindowUpdate()
+			Expect(queued).To(BeTrue())
+		})
 	})
 
 	Context("receiving data", func() {
@@ -175,12 +193,16 @@ var _ = Describe("Stream Flow controller", func() {
 				oldWindowSize = controller.receiveWindowSize
 			})
 
-			It("tells if it has window updates", func() {
-				Expect(controller.HasWindowUpdate()).To(BeFalse())
+			It("queues window updates", func() {
+				controller.MaybeQueueWindowUpdate()
+				Expect(queuedWindowUpdate).To(BeFalse())
 				controller.AddBytesRead(30)
-				Expect(controller.HasWindowUpdate()).To(BeTrue())
+				controller.MaybeQueueWindowUpdate()
+				Expect(queuedWindowUpdate).To(BeTrue())
 				Expect(controller.GetWindowUpdate()).ToNot(BeZero())
-				Expect(controller.HasWindowUpdate()).To(BeFalse())
+				queuedWindowUpdate = false
+				controller.MaybeQueueWindowUpdate()
+				Expect(queuedWindowUpdate).To(BeFalse())
 			})
 
 			It("tells the connection flow controller when the window was autotuned", func() {
@@ -213,7 +235,8 @@ var _ = Describe("Stream Flow controller", func() {
 				controller.AddBytesRead(30)
 				err := controller.UpdateHighestReceived(90, true)
 				Expect(err).ToNot(HaveOccurred())
-				Expect(controller.HasWindowUpdate()).To(BeFalse())
+				controller.MaybeQueueWindowUpdate()
+				Expect(queuedWindowUpdate).To(BeFalse())
 				offset := controller.GetWindowUpdate()
 				Expect(offset).To(BeZero())
 			})

+ 10 - 12
internal/mocks/stream_flow_controller.go

@@ -66,18 +66,6 @@ func (mr *MockStreamFlowControllerMockRecorder) GetWindowUpdate() *gomock.Call {
 	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetWindowUpdate", reflect.TypeOf((*MockStreamFlowController)(nil).GetWindowUpdate))
 }
 
-// HasWindowUpdate mocks base method
-func (m *MockStreamFlowController) HasWindowUpdate() bool {
-	ret := m.ctrl.Call(m, "HasWindowUpdate")
-	ret0, _ := ret[0].(bool)
-	return ret0
-}
-
-// HasWindowUpdate indicates an expected call of HasWindowUpdate
-func (mr *MockStreamFlowControllerMockRecorder) HasWindowUpdate() *gomock.Call {
-	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HasWindowUpdate", reflect.TypeOf((*MockStreamFlowController)(nil).HasWindowUpdate))
-}
-
 // IsBlocked mocks base method
 func (m *MockStreamFlowController) IsBlocked() (bool, protocol.ByteCount) {
 	ret := m.ctrl.Call(m, "IsBlocked")
@@ -91,6 +79,16 @@ func (mr *MockStreamFlowControllerMockRecorder) IsBlocked() *gomock.Call {
 	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsBlocked", reflect.TypeOf((*MockStreamFlowController)(nil).IsBlocked))
 }
 
+// MaybeQueueWindowUpdate mocks base method
+func (m *MockStreamFlowController) MaybeQueueWindowUpdate() {
+	m.ctrl.Call(m, "MaybeQueueWindowUpdate")
+}
+
+// MaybeQueueWindowUpdate indicates an expected call of MaybeQueueWindowUpdate
+func (mr *MockStreamFlowControllerMockRecorder) MaybeQueueWindowUpdate() *gomock.Call {
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MaybeQueueWindowUpdate", reflect.TypeOf((*MockStreamFlowController)(nil).MaybeQueueWindowUpdate))
+}
+
 // SendWindowSize mocks base method
 func (m *MockStreamFlowController) SendWindowSize() protocol.ByteCount {
 	ret := m.ctrl.Call(m, "SendWindowSize")

+ 0 - 10
mock_stream_sender_test.go

@@ -45,16 +45,6 @@ func (mr *MockStreamSenderMockRecorder) onHasStreamData(arg0 interface{}) *gomoc
 	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "onHasStreamData", reflect.TypeOf((*MockStreamSender)(nil).onHasStreamData), arg0)
 }
 
-// onHasWindowUpdate mocks base method
-func (m *MockStreamSender) onHasWindowUpdate(arg0 protocol.StreamID) {
-	m.ctrl.Call(m, "onHasWindowUpdate", arg0)
-}
-
-// onHasWindowUpdate indicates an expected call of onHasWindowUpdate
-func (mr *MockStreamSenderMockRecorder) onHasWindowUpdate(arg0 interface{}) *gomock.Call {
-	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "onHasWindowUpdate", reflect.TypeOf((*MockStreamSender)(nil).onHasWindowUpdate), arg0)
-}
-
 // onStreamCompleted mocks base method
 func (m *MockStreamSender) onStreamCompleted(arg0 protocol.StreamID) {
 	m.ctrl.Call(m, "onStreamCompleted", arg0)

+ 2 - 4
receive_stream.go

@@ -151,10 +151,8 @@ func (s *receiveStream) Read(p []byte) (int, error) {
 		if !s.resetRemotely {
 			s.flowController.AddBytesRead(protocol.ByteCount(m))
 		}
-		// this call triggers the flow controller to increase the flow control window, if necessary
-		if s.flowController.HasWindowUpdate() {
-			s.sender.onHasWindowUpdate(s.streamID)
-		}
+		// increase the flow control window, if necessary
+		s.flowController.MaybeQueueWindowUpdate()
 
 		if s.readPosInFrame >= int(frame.DataLen()) {
 			s.frameQueue.Pop()

+ 15 - 31
receive_stream_test.go

@@ -43,7 +43,7 @@ var _ = Describe("Receive Stream", func() {
 		It("reads a single STREAM frame", func() {
 			mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), false)
 			mockFC.EXPECT().AddBytesRead(protocol.ByteCount(4))
-			mockFC.EXPECT().HasWindowUpdate()
+			mockFC.EXPECT().MaybeQueueWindowUpdate()
 			frame := wire.StreamFrame{
 				Offset: 0,
 				Data:   []byte{0xDE, 0xAD, 0xBE, 0xEF},
@@ -61,7 +61,7 @@ var _ = Describe("Receive Stream", func() {
 			mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), false)
 			mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2))
 			mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2))
-			mockFC.EXPECT().HasWindowUpdate().Times(2)
+			mockFC.EXPECT().MaybeQueueWindowUpdate().Times(2)
 			frame := wire.StreamFrame{
 				Offset: 0,
 				Data:   []byte{0xDE, 0xAD, 0xBE, 0xEF},
@@ -83,7 +83,7 @@ var _ = Describe("Receive Stream", func() {
 			mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(2), false)
 			mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), false)
 			mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)).Times(2)
-			mockFC.EXPECT().HasWindowUpdate().Times(2)
+			mockFC.EXPECT().MaybeQueueWindowUpdate().Times(2)
 			frame1 := wire.StreamFrame{
 				Offset: 0,
 				Data:   []byte{0xDE, 0xAD},
@@ -107,7 +107,7 @@ var _ = Describe("Receive Stream", func() {
 			mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(2), false)
 			mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), false)
 			mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)).Times(2)
-			mockFC.EXPECT().HasWindowUpdate().Times(2)
+			mockFC.EXPECT().MaybeQueueWindowUpdate().Times(2)
 			frame1 := wire.StreamFrame{
 				Offset: 0,
 				Data:   []byte{0xDE, 0xAD},
@@ -130,7 +130,7 @@ var _ = Describe("Receive Stream", func() {
 		It("waits until data is available", func() {
 			mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(2), false)
 			mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2))
-			mockFC.EXPECT().HasWindowUpdate()
+			mockFC.EXPECT().MaybeQueueWindowUpdate()
 			go func() {
 				defer GinkgoRecover()
 				frame := wire.StreamFrame{Data: []byte{0xDE, 0xAD}}
@@ -148,7 +148,7 @@ var _ = Describe("Receive Stream", func() {
 			mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(2), false)
 			mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), false)
 			mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)).Times(2)
-			mockFC.EXPECT().HasWindowUpdate().Times(2)
+			mockFC.EXPECT().MaybeQueueWindowUpdate().Times(2)
 			frame1 := wire.StreamFrame{
 				Offset: 2,
 				Data:   []byte{0xBE, 0xEF},
@@ -173,7 +173,7 @@ var _ = Describe("Receive Stream", func() {
 			mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(2), false)
 			mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), false)
 			mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)).Times(2)
-			mockFC.EXPECT().HasWindowUpdate().Times(2)
+			mockFC.EXPECT().MaybeQueueWindowUpdate().Times(2)
 			frame1 := wire.StreamFrame{
 				Offset: 0,
 				Data:   []byte{0xDE, 0xAD},
@@ -204,7 +204,7 @@ var _ = Describe("Receive Stream", func() {
 			mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(6), false)
 			mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2))
 			mockFC.EXPECT().AddBytesRead(protocol.ByteCount(4))
-			mockFC.EXPECT().HasWindowUpdate().Times(2)
+			mockFC.EXPECT().MaybeQueueWindowUpdate().Times(2)
 			frame1 := wire.StreamFrame{
 				Offset: 0,
 				Data:   []byte("foob"),
@@ -230,22 +230,6 @@ var _ = Describe("Receive Stream", func() {
 			Expect(err).To(MatchError(errEmptyStreamData))
 		})
 
-		It("calls the onHasWindowUpdate callback, when the a MAX_STREAM_DATA should be sent", func() {
-			mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(6), false)
-			mockFC.EXPECT().AddBytesRead(protocol.ByteCount(6))
-			mockFC.EXPECT().HasWindowUpdate().Return(true)
-			mockSender.EXPECT().onHasWindowUpdate(streamID)
-			frame1 := wire.StreamFrame{
-				Offset: 0,
-				Data:   []byte("foobar"),
-			}
-			err := str.handleStreamFrame(&frame1)
-			Expect(err).ToNot(HaveOccurred())
-			b := make([]byte, 6)
-			_, err = strWithTimeout.Read(b)
-			Expect(err).ToNot(HaveOccurred())
-		})
-
 		Context("deadlines", func() {
 			It("the deadline error has the right net.Error properties", func() {
 				Expect(errDeadline.Temporary()).To(BeTrue())
@@ -318,7 +302,7 @@ var _ = Describe("Receive Stream", func() {
 				It("returns EOFs", func() {
 					mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), true)
 					mockFC.EXPECT().AddBytesRead(protocol.ByteCount(4))
-					mockFC.EXPECT().HasWindowUpdate()
+					mockFC.EXPECT().MaybeQueueWindowUpdate()
 					str.handleStreamFrame(&wire.StreamFrame{
 						Offset: 0,
 						Data:   []byte{0xDE, 0xAD, 0xBE, 0xEF},
@@ -339,7 +323,7 @@ var _ = Describe("Receive Stream", func() {
 					mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(2), false)
 					mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), true)
 					mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)).Times(2)
-					mockFC.EXPECT().HasWindowUpdate().Times(2)
+					mockFC.EXPECT().MaybeQueueWindowUpdate().Times(2)
 					frame1 := wire.StreamFrame{
 						Offset: 2,
 						Data:   []byte{0xBE, 0xEF},
@@ -367,7 +351,7 @@ var _ = Describe("Receive Stream", func() {
 				It("returns EOFs with partial read", func() {
 					mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(2), true)
 					mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2))
-					mockFC.EXPECT().HasWindowUpdate()
+					mockFC.EXPECT().MaybeQueueWindowUpdate()
 					err := str.handleStreamFrame(&wire.StreamFrame{
 						Offset: 0,
 						Data:   []byte{0xde, 0xad},
@@ -385,7 +369,7 @@ var _ = Describe("Receive Stream", func() {
 				It("handles immediate FINs", func() {
 					mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(0), true)
 					mockFC.EXPECT().AddBytesRead(protocol.ByteCount(0))
-					mockFC.EXPECT().HasWindowUpdate()
+					mockFC.EXPECT().MaybeQueueWindowUpdate()
 					err := str.handleStreamFrame(&wire.StreamFrame{
 						Offset: 0,
 						FinBit: true,
@@ -402,7 +386,7 @@ var _ = Describe("Receive Stream", func() {
 			It("closes when CloseRemote is called", func() {
 				mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(0), true)
 				mockFC.EXPECT().AddBytesRead(protocol.ByteCount(0))
-				mockFC.EXPECT().HasWindowUpdate()
+				mockFC.EXPECT().MaybeQueueWindowUpdate()
 				str.CloseRemote(0)
 				mockSender.EXPECT().onStreamCompleted(streamID)
 				b := make([]byte, 8)
@@ -478,7 +462,7 @@ var _ = Describe("Receive Stream", func() {
 			It("doesn't send a RST_STREAM frame, if the FIN was already read", func() {
 				mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(6), true)
 				mockFC.EXPECT().AddBytesRead(protocol.ByteCount(6))
-				mockFC.EXPECT().HasWindowUpdate()
+				mockFC.EXPECT().MaybeQueueWindowUpdate()
 				// no calls to mockSender.queueControlFrame
 				err := str.handleStreamFrame(&wire.StreamFrame{
 					StreamID: streamID,
@@ -601,7 +585,7 @@ var _ = Describe("Receive Stream", func() {
 						mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)),
 						mockSender.EXPECT().onStreamCompleted(streamID),
 					)
-					mockFC.EXPECT().HasWindowUpdate().Times(2)
+					mockFC.EXPECT().MaybeQueueWindowUpdate().Times(2)
 					readReturned := make(chan struct{})
 					go func() {
 						defer GinkgoRecover()

+ 2 - 0
session.go

@@ -1137,6 +1137,7 @@ func (s *session) newFlowController(id protocol.StreamID) flowcontrol.StreamFlow
 		protocol.ReceiveStreamFlowControlWindow,
 		protocol.ByteCount(s.config.MaxReceiveStreamFlowControlWindow),
 		initialSendWindow,
+		s.onHasWindowUpdate,
 		s.rttStats,
 		s.logger,
 	)
@@ -1151,6 +1152,7 @@ func (s *session) newCryptoStream() cryptoStreamI {
 		protocol.ReceiveStreamFlowControlWindow,
 		protocol.ByteCount(s.config.MaxReceiveStreamFlowControlWindow),
 		0,
+		s.onHasWindowUpdate,
 		s.rttStats,
 		s.logger,
 	)

+ 0 - 5
stream.go

@@ -18,7 +18,6 @@ const (
 // The streamSender is notified by the stream about various events.
 type streamSender interface {
 	queueControlFrame(wire.Frame)
-	onHasWindowUpdate(protocol.StreamID)
 	onHasStreamData(protocol.StreamID)
 	onStreamCompleted(protocol.StreamID)
 }
@@ -34,10 +33,6 @@ func (s *uniStreamSender) queueControlFrame(f wire.Frame) {
 	s.streamSender.queueControlFrame(f)
 }
 
-func (s *uniStreamSender) onHasWindowUpdate(id protocol.StreamID) {
-	s.streamSender.onHasWindowUpdate(id)
-}
-
 func (s *uniStreamSender) onHasStreamData(id protocol.StreamID) {
 	s.streamSender.onHasStreamData(id)
 }