// Copyright (c) 2015 Uber Technologies, Inc. // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. package tchannel import ( "errors" "fmt" "sync" "github.com/uber/tchannel-go/typed" "go.uber.org/atomic" "golang.org/x/net/context" ) var ( errDuplicateMex = errors.New("multiple attempts to use the message id") errMexShutdown = errors.New("mex has been shutdown") errMexSetShutdown = errors.New("mexset has been shutdown") errMexChannelFull = NewSystemError(ErrCodeBusy, "cannot send frame to message exchange channel") errUnexpectedFrameType = errors.New("unexpected frame received") ) const ( messageExchangeSetInbound = "inbound" messageExchangeSetOutbound = "outbound" // mexChannelBufferSize is the size of the message exchange channel buffer. mexChannelBufferSize = 2 ) type errNotifier struct { c chan struct{} err error notified atomic.Bool } func newErrNotifier() errNotifier { return errNotifier{c: make(chan struct{})} } // Notify will store the error and notify all waiters on c that there's an error. func (e *errNotifier) Notify(err error) error { // The code should never try to Notify(nil). if err == nil { panic("cannot Notify with no error") } // There may be some sort of race where we try to notify the mex twice. if !e.notified.CAS(false, true) { return fmt.Errorf("cannot broadcast error: %v, already have: %v", err, e.err) } e.err = err close(e.c) return nil } // checkErr returns previously notified errors (if any). func (e *errNotifier) checkErr() error { select { case <-e.c: return e.err default: return nil } } // A messageExchange tracks this Connections's side of a message exchange with a // peer. Each message exchange has a channel that can be used to receive // frames from the peer, and a Context that can controls when the exchange has // timed out or been cancelled. type messageExchange struct { recvCh chan *Frame errCh errNotifier ctx context.Context msgID uint32 msgType messageType mexset *messageExchangeSet framePool FramePool shutdownAtomic atomic.Bool errChNotified atomic.Bool } // checkError is called before waiting on the mex channels. // It returns any existing errors (timeout, cancellation, connection errors). func (mex *messageExchange) checkError() error { if err := mex.ctx.Err(); err != nil { return GetContextError(err) } return mex.errCh.checkErr() } // forwardPeerFrame forwards a frame from a peer to the message exchange, where // it can be pulled by whatever application thread is handling the exchange func (mex *messageExchange) forwardPeerFrame(frame *Frame) error { // We want a very specific priority here: // 1. Timeouts/cancellation (mex.ctx errors) // 2. Whether recvCh has buffer space (non-blocking select over mex.recvCh) // 3. Other mex errors (mex.errCh) // Which is why we check the context error only (instead of mex.checkError). // In the mex.errCh case, we do a non-blocking write to recvCh to prioritize it. if err := mex.ctx.Err(); err != nil { return GetContextError(err) } select { case mex.recvCh <- frame: return nil case <-mex.ctx.Done(): // Note: One slow reader processing a large request could stall the connection. // If we see this, we need to increase the recvCh buffer size. return GetContextError(mex.ctx.Err()) case <-mex.errCh.c: // Select will randomly choose a case, but we want to prioritize // sending a frame over the errCh. Try a non-blocking write. select { case mex.recvCh <- frame: return nil default: } return mex.errCh.err } } func (mex *messageExchange) checkFrame(frame *Frame) error { if frame.Header.ID != mex.msgID { mex.mexset.log.WithFields( LogField{"msgId", mex.msgID}, LogField{"header", frame.Header}, ).Error("recvPeerFrame received msg with unexpected ID.") return errUnexpectedFrameType } return nil } // recvPeerFrame waits for a new frame from the peer, or until the context // expires or is cancelled func (mex *messageExchange) recvPeerFrame() (*Frame, error) { // We have to check frames/errors in a very specific order here: // 1. Timeouts/cancellation (mex.ctx errors) // 2. Any pending frames (non-blocking select over mex.recvCh) // 3. Other mex errors (mex.errCh) // Which is why we check the context error only (instead of mex.checkError)e // In the mex.errCh case, we do a non-blocking read from recvCh to prioritize it. if err := mex.ctx.Err(); err != nil { return nil, GetContextError(err) } select { case frame := <-mex.recvCh: if err := mex.checkFrame(frame); err != nil { return nil, err } return frame, nil case <-mex.ctx.Done(): return nil, GetContextError(mex.ctx.Err()) case <-mex.errCh.c: // Select will randomly choose a case, but we want to prioritize // receiving a frame over errCh. Try a non-blocking read. select { case frame := <-mex.recvCh: if err := mex.checkFrame(frame); err != nil { return nil, err } return frame, nil default: } return nil, mex.errCh.err } } // recvPeerFrameOfType waits for a new frame of a given type from the peer, failing // if the next frame received is not of that type. // If an error frame is returned, then the errorMessage is returned as the error. func (mex *messageExchange) recvPeerFrameOfType(msgType messageType) (*Frame, error) { frame, err := mex.recvPeerFrame() if err != nil { return nil, err } switch frame.Header.messageType { case msgType: return frame, nil case messageTypeError: // If we read an error frame, we can release it once we deserialize it. defer mex.framePool.Release(frame) errMsg := errorMessage{ id: frame.Header.ID, } var rbuf typed.ReadBuffer rbuf.Wrap(frame.SizedPayload()) if err := errMsg.read(&rbuf); err != nil { return nil, err } return nil, errMsg default: // TODO(mmihic): Should be treated as a protocol error mex.mexset.log.WithFields( LogField{"header", frame.Header}, LogField{"expectedType", msgType}, LogField{"expectedID", mex.msgID}, ).Warn("Received unexpected frame.") return nil, errUnexpectedFrameType } } // shutdown shuts down the message exchange, removing it from the message // exchange set so that it cannot receive more messages from the peer. The // receive channel remains open, however, in case there are concurrent // goroutines sending to it. func (mex *messageExchange) shutdown() { // The reader and writer side can both hit errors and try to shutdown the mex, // so we ensure that it's only shut down once. if !mex.shutdownAtomic.CAS(false, true) { return } if mex.errChNotified.CAS(false, true) { mex.errCh.Notify(errMexShutdown) } mex.mexset.removeExchange(mex.msgID) } // inboundExpired is called when an exchange is canceled or it times out, // but a handler may still be running in the background. Since the handler may // still write to the exchange, we cannot shutdown the exchange, but we should // remove it from the connection's exchange list. func (mex *messageExchange) inboundExpired() { mex.mexset.expireExchange(mex.msgID) } // A messageExchangeSet manages a set of active message exchanges. It is // mainly used to route frames from a peer to the appropriate messageExchange, // or to cancel or mark a messageExchange as being in error. Each Connection // maintains two messageExchangeSets, one to manage exchanges that it has // initiated (outbound), and another to manage exchanges that the peer has // initiated (inbound). The message-type specific handlers are responsible for // ensuring that their message exchanges are properly registered and removed // from the corresponding exchange set. type messageExchangeSet struct { sync.RWMutex log Logger name string onRemoved func() onAdded func() // maps are mutable, and are protected by the mutex. exchanges map[uint32]*messageExchange expiredExchanges map[uint32]struct{} shutdown bool } // newMessageExchangeSet creates a new messageExchangeSet with a given name. func newMessageExchangeSet(log Logger, name string) *messageExchangeSet { return &messageExchangeSet{ name: name, log: log.WithFields(LogField{"exchange", name}), exchanges: make(map[uint32]*messageExchange), expiredExchanges: make(map[uint32]struct{}), } } // addExchange adds an exchange, it must be called with the mexset locked. func (mexset *messageExchangeSet) addExchange(mex *messageExchange) error { if mexset.shutdown { return errMexSetShutdown } if _, ok := mexset.exchanges[mex.msgID]; ok { return errDuplicateMex } mexset.exchanges[mex.msgID] = mex return nil } // newExchange creates and adds a new message exchange to this set func (mexset *messageExchangeSet) newExchange(ctx context.Context, framePool FramePool, msgType messageType, msgID uint32, bufferSize int) (*messageExchange, error) { if mexset.log.Enabled(LogLevelDebug) { mexset.log.Debugf("Creating new %s message exchange for [%v:%d]", mexset.name, msgType, msgID) } mex := &messageExchange{ msgType: msgType, msgID: msgID, ctx: ctx, recvCh: make(chan *Frame, bufferSize), errCh: newErrNotifier(), mexset: mexset, framePool: framePool, } mexset.Lock() addErr := mexset.addExchange(mex) mexset.Unlock() if addErr != nil { logger := mexset.log.WithFields( LogField{"msgID", mex.msgID}, LogField{"msgType", mex.msgType}, LogField{"exchange", mexset.name}, ) if addErr == errMexSetShutdown { logger.Warn("Attempted to create new mex after mexset shutdown.") } else if addErr == errDuplicateMex { logger.Warn("Duplicate msg ID for active and new mex.") } return nil, addErr } mexset.onAdded() // TODO(mmihic): Put into a deadline ordered heap so we can garbage collected expired exchanges return mex, nil } // deleteExchange will delete msgID, and return whether it was found or whether it was // timed out. This method must be called with the lock. func (mexset *messageExchangeSet) deleteExchange(msgID uint32) (found, timedOut bool) { if _, found := mexset.exchanges[msgID]; found { delete(mexset.exchanges, msgID) return true, false } if _, expired := mexset.expiredExchanges[msgID]; expired { delete(mexset.expiredExchanges, msgID) return false, true } return false, false } // removeExchange removes a message exchange from the set, if it exists. func (mexset *messageExchangeSet) removeExchange(msgID uint32) { if mexset.log.Enabled(LogLevelDebug) { mexset.log.Debugf("Removing %s message exchange %d", mexset.name, msgID) } mexset.Lock() found, expired := mexset.deleteExchange(msgID) mexset.Unlock() if !found && !expired { mexset.log.WithFields( LogField{"msgID", msgID}, ).Error("Tried to remove exchange multiple times") return } // If the message exchange was found, then we perform clean up actions. // These clean up actions can only be run once per exchange. mexset.onRemoved() } // expireExchange is similar to removeExchange, but it marks the exchange as // expired. func (mexset *messageExchangeSet) expireExchange(msgID uint32) { mexset.log.Debugf( "Removing %s message exchange %d due to timeout, cancellation or blackhole", mexset.name, msgID, ) mexset.Lock() // TODO(aniketp): explore if cancel can be called everytime we expire an exchange found, expired := mexset.deleteExchange(msgID) if found || expired { // Record in expiredExchanges if we deleted the exchange. mexset.expiredExchanges[msgID] = struct{}{} } mexset.Unlock() if expired { mexset.log.WithFields(LogField{"msgID", msgID}).Info("Exchange expired already") } mexset.onRemoved() } func (mexset *messageExchangeSet) count() int { mexset.RLock() count := len(mexset.exchanges) mexset.RUnlock() return count } // forwardPeerFrame forwards a frame from the peer to the appropriate message // exchange func (mexset *messageExchangeSet) forwardPeerFrame(frame *Frame) error { if mexset.log.Enabled(LogLevelDebug) { mexset.log.Debugf("forwarding %s %s", mexset.name, frame.Header) } mexset.RLock() mex := mexset.exchanges[frame.Header.ID] mexset.RUnlock() if mex == nil { // This is ok since the exchange might have expired or been cancelled mexset.log.WithFields( LogField{"frameHeader", frame.Header.String()}, LogField{"exchange", mexset.name}, ).Info("Received frame for unknown message exchange.") return nil } if err := mex.forwardPeerFrame(frame); err != nil { mexset.log.WithFields( LogField{"frameHeader", frame.Header.String()}, LogField{"frameSize", frame.Header.FrameSize()}, LogField{"exchange", mexset.name}, ErrField(err), ).Info("Failed to forward frame.") return err } return nil } // copyExchanges returns a copy of the exchanges if the exchange is active. // The caller must lock the mexset. func (mexset *messageExchangeSet) copyExchanges() (shutdown bool, exchanges map[uint32]*messageExchange) { if mexset.shutdown { return true, nil } exchangesCopy := make(map[uint32]*messageExchange, len(mexset.exchanges)) for k, mex := range mexset.exchanges { exchangesCopy[k] = mex } return false, exchangesCopy } // stopExchanges stops all message exchanges to unblock all waiters on the mex. // This should only be called on connection failures. func (mexset *messageExchangeSet) stopExchanges(err error) { if mexset.log.Enabled(LogLevelDebug) { mexset.log.Debugf("stopping %v exchanges due to error: %v", mexset.count(), err) } mexset.Lock() shutdown, exchanges := mexset.copyExchanges() mexset.shutdown = true mexset.Unlock() if shutdown { mexset.log.Debugf("mexset has already been shutdown") return } for _, mex := range exchanges { // When there's a connection failure, we want to notify blocked callers that the // call will fail, but we don't want to shutdown the exchange as only the // arg reader/writer should shutdown the exchange. Otherwise, our guarantee // on sendChRefs that there's no references to sendCh is violated since // readers/writers could still have a reference to sendCh even though // we shutdown the exchange and called Done on sendChRefs. if mex.errChNotified.CAS(false, true) { mex.errCh.Notify(err) } } }