streams_map.go 7.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245
  1. package quic
  2. import (
  3. "context"
  4. "errors"
  5. "fmt"
  6. "net"
  7. "github.com/lucas-clemente/quic-go/internal/flowcontrol"
  8. "github.com/lucas-clemente/quic-go/internal/handshake"
  9. "github.com/lucas-clemente/quic-go/internal/protocol"
  10. "github.com/lucas-clemente/quic-go/internal/qerr"
  11. "github.com/lucas-clemente/quic-go/internal/wire"
  12. )
  13. type streamError struct {
  14. message string
  15. nums []protocol.StreamNum
  16. }
  17. func (e streamError) Error() string {
  18. return e.message
  19. }
  20. func convertStreamError(err error, stype protocol.StreamType, pers protocol.Perspective) error {
  21. strError, ok := err.(streamError)
  22. if !ok {
  23. return err
  24. }
  25. ids := make([]interface{}, len(strError.nums))
  26. for i, num := range strError.nums {
  27. ids[i] = num.StreamID(stype, pers)
  28. }
  29. return fmt.Errorf(strError.Error(), ids...)
  30. }
  // streamOpenErr wraps an error from opening a stream and exposes it as a
  // net.Error.
  type streamOpenErr struct{ error }

  var _ net.Error = &streamOpenErr{}

  // Temporary reports whether the wrapped error is errTooManyOpenStreams.
  func (e streamOpenErr) Temporary() bool { return errTooManyOpenStreams == e.error }

  // Timeout always returns false.
  func (streamOpenErr) Timeout() bool { return false }

  // errTooManyOpenStreams is used internally by the outgoing streams maps.
  var errTooManyOpenStreams = errors.New("too many open streams")
  37. type streamsMap struct {
  38. perspective protocol.Perspective
  39. sender streamSender
  40. newFlowController func(protocol.StreamID) flowcontrol.StreamFlowController
  41. outgoingBidiStreams *outgoingBidiStreamsMap
  42. outgoingUniStreams *outgoingUniStreamsMap
  43. incomingBidiStreams *incomingBidiStreamsMap
  44. incomingUniStreams *incomingUniStreamsMap
  45. }
  46. var _ streamManager = &streamsMap{}
  47. func newStreamsMap(
  48. sender streamSender,
  49. newFlowController func(protocol.StreamID) flowcontrol.StreamFlowController,
  50. maxIncomingBidiStreams uint64,
  51. maxIncomingUniStreams uint64,
  52. perspective protocol.Perspective,
  53. version protocol.VersionNumber,
  54. ) streamManager {
  55. m := &streamsMap{
  56. perspective: perspective,
  57. newFlowController: newFlowController,
  58. sender: sender,
  59. }
  60. m.outgoingBidiStreams = newOutgoingBidiStreamsMap(
  61. func(num protocol.StreamNum) streamI {
  62. id := num.StreamID(protocol.StreamTypeBidi, perspective)
  63. return newStream(id, m.sender, m.newFlowController(id), version)
  64. },
  65. sender.queueControlFrame,
  66. )
  67. m.incomingBidiStreams = newIncomingBidiStreamsMap(
  68. func(num protocol.StreamNum) streamI {
  69. id := num.StreamID(protocol.StreamTypeBidi, perspective.Opposite())
  70. return newStream(id, m.sender, m.newFlowController(id), version)
  71. },
  72. maxIncomingBidiStreams,
  73. sender.queueControlFrame,
  74. )
  75. m.outgoingUniStreams = newOutgoingUniStreamsMap(
  76. func(num protocol.StreamNum) sendStreamI {
  77. id := num.StreamID(protocol.StreamTypeUni, perspective)
  78. return newSendStream(id, m.sender, m.newFlowController(id), version)
  79. },
  80. sender.queueControlFrame,
  81. )
  82. m.incomingUniStreams = newIncomingUniStreamsMap(
  83. func(num protocol.StreamNum) receiveStreamI {
  84. id := num.StreamID(protocol.StreamTypeUni, perspective.Opposite())
  85. return newReceiveStream(id, m.sender, m.newFlowController(id), version)
  86. },
  87. maxIncomingUniStreams,
  88. sender.queueControlFrame,
  89. )
  90. return m
  91. }
  92. func (m *streamsMap) OpenStream() (Stream, error) {
  93. str, err := m.outgoingBidiStreams.OpenStream()
  94. return str, convertStreamError(err, protocol.StreamTypeBidi, m.perspective)
  95. }
  96. func (m *streamsMap) OpenStreamSync(ctx context.Context) (Stream, error) {
  97. str, err := m.outgoingBidiStreams.OpenStreamSync(ctx)
  98. return str, convertStreamError(err, protocol.StreamTypeBidi, m.perspective)
  99. }
  100. func (m *streamsMap) OpenUniStream() (SendStream, error) {
  101. str, err := m.outgoingUniStreams.OpenStream()
  102. return str, convertStreamError(err, protocol.StreamTypeBidi, m.perspective)
  103. }
  104. func (m *streamsMap) OpenUniStreamSync(ctx context.Context) (SendStream, error) {
  105. str, err := m.outgoingUniStreams.OpenStreamSync(ctx)
  106. return str, convertStreamError(err, protocol.StreamTypeUni, m.perspective)
  107. }
  108. func (m *streamsMap) AcceptStream(ctx context.Context) (Stream, error) {
  109. str, err := m.incomingBidiStreams.AcceptStream(ctx)
  110. return str, convertStreamError(err, protocol.StreamTypeBidi, m.perspective.Opposite())
  111. }
  112. func (m *streamsMap) AcceptUniStream(ctx context.Context) (ReceiveStream, error) {
  113. str, err := m.incomingUniStreams.AcceptStream(ctx)
  114. return str, convertStreamError(err, protocol.StreamTypeUni, m.perspective.Opposite())
  115. }
  116. func (m *streamsMap) DeleteStream(id protocol.StreamID) error {
  117. num := id.StreamNum()
  118. switch id.Type() {
  119. case protocol.StreamTypeUni:
  120. if id.InitiatedBy() == m.perspective {
  121. return m.outgoingUniStreams.DeleteStream(num)
  122. }
  123. return m.incomingUniStreams.DeleteStream(num)
  124. case protocol.StreamTypeBidi:
  125. if id.InitiatedBy() == m.perspective {
  126. return m.outgoingBidiStreams.DeleteStream(num)
  127. }
  128. return m.incomingBidiStreams.DeleteStream(num)
  129. }
  130. panic("")
  131. }
  132. func (m *streamsMap) GetOrOpenReceiveStream(id protocol.StreamID) (receiveStreamI, error) {
  133. str, err := m.getOrOpenReceiveStream(id)
  134. if err != nil {
  135. return nil, qerr.Error(qerr.StreamStateError, err.Error())
  136. }
  137. return str, nil
  138. }
  139. func (m *streamsMap) getOrOpenReceiveStream(id protocol.StreamID) (receiveStreamI, error) {
  140. num := id.StreamNum()
  141. switch id.Type() {
  142. case protocol.StreamTypeUni:
  143. if id.InitiatedBy() == m.perspective {
  144. // an outgoing unidirectional stream is a send stream, not a receive stream
  145. return nil, fmt.Errorf("peer attempted to open receive stream %d", id)
  146. }
  147. str, err := m.incomingUniStreams.GetOrOpenStream(num)
  148. return str, convertStreamError(err, protocol.StreamTypeUni, m.perspective)
  149. case protocol.StreamTypeBidi:
  150. var str receiveStreamI
  151. var err error
  152. if id.InitiatedBy() == m.perspective {
  153. str, err = m.outgoingBidiStreams.GetStream(num)
  154. } else {
  155. str, err = m.incomingBidiStreams.GetOrOpenStream(num)
  156. }
  157. return str, convertStreamError(err, protocol.StreamTypeBidi, id.InitiatedBy())
  158. }
  159. panic("")
  160. }
  161. func (m *streamsMap) GetOrOpenSendStream(id protocol.StreamID) (sendStreamI, error) {
  162. str, err := m.getOrOpenSendStream(id)
  163. if err != nil {
  164. return nil, qerr.Error(qerr.StreamStateError, err.Error())
  165. }
  166. return str, nil
  167. }
  168. func (m *streamsMap) getOrOpenSendStream(id protocol.StreamID) (sendStreamI, error) {
  169. num := id.StreamNum()
  170. switch id.Type() {
  171. case protocol.StreamTypeUni:
  172. if id.InitiatedBy() == m.perspective {
  173. str, err := m.outgoingUniStreams.GetStream(num)
  174. return str, convertStreamError(err, protocol.StreamTypeUni, m.perspective)
  175. }
  176. // an incoming unidirectional stream is a receive stream, not a send stream
  177. return nil, fmt.Errorf("peer attempted to open send stream %d", id)
  178. case protocol.StreamTypeBidi:
  179. var str sendStreamI
  180. var err error
  181. if id.InitiatedBy() == m.perspective {
  182. str, err = m.outgoingBidiStreams.GetStream(num)
  183. } else {
  184. str, err = m.incomingBidiStreams.GetOrOpenStream(num)
  185. }
  186. return str, convertStreamError(err, protocol.StreamTypeBidi, id.InitiatedBy())
  187. }
  188. panic("")
  189. }
  190. func (m *streamsMap) HandleMaxStreamsFrame(f *wire.MaxStreamsFrame) error {
  191. if f.MaxStreamNum > protocol.MaxStreamCount {
  192. return qerr.StreamLimitError
  193. }
  194. switch f.Type {
  195. case protocol.StreamTypeUni:
  196. m.outgoingUniStreams.SetMaxStream(f.MaxStreamNum)
  197. case protocol.StreamTypeBidi:
  198. m.outgoingBidiStreams.SetMaxStream(f.MaxStreamNum)
  199. }
  200. return nil
  201. }
  202. func (m *streamsMap) UpdateLimits(p *handshake.TransportParameters) error {
  203. if p.MaxBidiStreamNum > protocol.MaxStreamCount ||
  204. p.MaxUniStreamNum > protocol.MaxStreamCount {
  205. return qerr.StreamLimitError
  206. }
  207. // Max{Uni,Bidi}StreamID returns the highest stream ID that the peer is allowed to open.
  208. m.outgoingBidiStreams.SetMaxStream(p.MaxBidiStreamNum)
  209. m.outgoingUniStreams.SetMaxStream(p.MaxUniStreamNum)
  210. return nil
  211. }
  212. func (m *streamsMap) CloseWithError(err error) {
  213. m.outgoingBidiStreams.CloseWithError(err)
  214. m.outgoingUniStreams.CloseWithError(err)
  215. m.incomingBidiStreams.CloseWithError(err)
  216. m.incomingUniStreams.CloseWithError(err)
  217. }