wheat-cache/gateway/proxy/proxy.go

132 lines
3.7 KiB
Go

package proxy
import (
"context"
"io"
wheatCodec "gitee.com/wheat-os/wheatCache/gateway/codec"
"gitee.com/wheat-os/wheatCache/gateway/transport"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
// TransparentHandler builds a grpc.StreamHandler that tries to proxy every request
// for which the server has no registered service.
// The intended use is as a transparent proxy: the server knows nothing about the
// services implemented by the backends, so the returned handler should be installed
// as a `grpc.UnknownServiceHandler`.
//
// director is invoked for each incoming stream to pick the backend; tranport is
// handed to director unchanged on every call.
//
// NOTE(review): the original comment states this can *only* be used if the server
// also uses the grpcproxy.CodecForServer() ServerOption; confirm the equivalent
// codec option for this project.
func TransparentHandler(director StreamDirector, tranport transport.TransPortInterface) grpc.StreamHandler {
	h := &handler{
		director:  director,
		transport: tranport,
	}
	return h.handler
}
// handler bundles the dependencies used by the proxying stream handler.
type handler struct {
// director is called once per incoming stream with the stream context, the full
// method name and the transport; it yields the outgoing context and the backend
// connection to proxy to.
director StreamDirector
// transport is passed through to director on every call.
transport transport.TransPortInterface
}
// handler does the actual proxying for a single incoming stream.
// It is invoked like any gRPC server stream handler: it asks the director for a
// backend connection, opens a ClientStream against it for the same method, and then
// pumps frames in both directions until one side finishes.
//
// Returns nil when the backend side ends with io.EOF, the backend's error when it
// ends with anything else, and a codes.Internal status error when the method name
// cannot be resolved or when reading from the incoming stream fails.
func (s *handler) handler(srv interface{}, serverStream grpc.ServerStream) error {
	method, found := grpc.MethodFromServerStream(serverStream)
	if !found {
		return status.Errorf(codes.Internal, "lowLevelServerStream not exists in context")
	}

	backendCtx, conn, err := s.director(serverStream.Context(), method, s.transport)
	if err != nil {
		return err
	}

	// A cancellable context so the backend stream is torn down whenever this handler returns.
	ctx, cancel := context.WithCancel(backendCtx)
	defer cancel()

	backendStream, err := grpc.NewClientStream(ctx, clientStreamDescForProxying, conn, method)
	if err != nil {
		return err
	}

	toBackendDone := s.forwardServerToClient(serverStream, backendStream)
	toCallerDone := s.forwardClientToServer(backendStream, serverStream)

	// Each pump reports exactly one result, so at most two results are awaited.
	for pending := 2; pending > 0; pending-- {
		select {
		case recvErr := <-toBackendDone:
			if recvErr != io.EOF {
				// Reading from the incoming stream failed: cancel the backend call and report it.
				cancel()
				return status.Errorf(codes.Internal, "failed proxying s2c: %v", recvErr)
			}
			// The caller finished sending and closed normally, so the proxy closes its
			// sending side towards the backend; the backend may still be producing
			// responses, so keep waiting for the other pump. The result of CloseSend
			// is not inspected.
			backendStream.CloseSend()
		case backendErr := <-toCallerDone:
			// This branch fires when the backend has no more data to provide (io.EOF)
			// or reported an error; pass along whatever trailers the backend stream holds.
			serverStream.SetTrailer(backendStream.Trailer())
			if backendErr == io.EOF {
				return nil
			}
			return backendErr
		}
	}
	return status.Errorf(codes.Internal, "gRPC proxying should never reach this stage.")
}
// forwardClientToServer starts a goroutine that copies frames from the backend
// ClientStream (src) to the incoming ServerStream (dst).
//
// The returned channel receives exactly one value: the error that stopped the copy.
// That error may be io.EOF, which is the normal end of the stream. The channel has
// capacity 1, so the goroutine can deliver its result and exit even if nobody is
// receiving.
func (s *handler) forwardClientToServer(src grpc.ClientStream, dst grpc.ServerStream) chan error {
	done := make(chan error, 1)

	// relay runs until a receive or send fails and returns that failure.
	relay := func() error {
		frame := &wheatCodec.Frame{}
		for first := true; ; first = false {
			if err := src.RecvMsg(frame); err != nil {
				return err // may be io.EOF, the happy case
			}
			if first {
				// A bit of a hack: the backend's headers only become readable once the
				// first message has been received, yet they must be written to the
				// server stream before the first message is flushed. This is the one
				// spot where both conditions hold.
				md, err := src.Header()
				if err != nil {
					return err
				}
				if err := dst.SendHeader(md); err != nil {
					return err
				}
			}
			if err := dst.SendMsg(frame); err != nil {
				return err
			}
		}
	}

	go func() {
		done <- relay()
	}()
	return done
}
// forwardServerToClient starts a goroutine that copies frames from the incoming
// ServerStream (src) to the backend ClientStream (dst).
//
// The returned channel receives exactly one value: the error that stopped the copy.
// That error may be io.EOF, which is the normal end of the stream. The channel has
// capacity 1, so the goroutine can deliver its result and exit even if nobody is
// receiving.
func (s *handler) forwardServerToClient(src grpc.ServerStream, dst grpc.ClientStream) chan error {
	done := make(chan error, 1)
	go func() {
		frame := &wheatCodec.Frame{}
		var err error
		for {
			// Stop on the first failure of either the receive or the send.
			if err = src.RecvMsg(frame); err != nil { // may be io.EOF, the happy case
				break
			}
			if err = dst.SendMsg(frame); err != nil {
				break
			}
		}
		done <- err
	}()
	return done
}