diff --git a/gateway/proxy/define.go b/gateway/proxy/define.go new file mode 100644 index 0000000..0167d65 --- /dev/null +++ b/gateway/proxy/define.go @@ -0,0 +1,16 @@ +package proxy + +import ( + "context" + + "google.golang.org/grpc" +) + +type StreamDirector func(ctx context.Context, fullMethodName string) (context.Context, *grpc.ClientConn, error) + +var ( + clientStreamDescForProxying = &grpc.StreamDesc{ + ServerStreams: true, + ClientStreams: true, + } +) diff --git a/gateway/proxy/director.go b/gateway/proxy/director.go new file mode 100644 index 0000000..9a99dab --- /dev/null +++ b/gateway/proxy/director.go @@ -0,0 +1,16 @@ +package proxy + +import ( + "context" + + "gitee.com/timedb/wheatCache/gateway/codec" + "google.golang.org/grpc" +) + +func GetDirectorByServiceHash() StreamDirector { + return func(ctx context.Context, fullMethodName string) (context.Context, *grpc.ClientConn, error) { + // TODO hash, mock 直接转发到 storage dev 上 + cli, err := grpc.DialContext(ctx, "127.0.0.1:5890", grpc.WithInsecure(), grpc.ForceCodec(codec.Codec())) + return ctx, cli, err + } +} diff --git a/gateway/proxy/proxy.go b/gateway/proxy/proxy.go new file mode 100644 index 0000000..df475d5 --- /dev/null +++ b/gateway/proxy/proxy.go @@ -0,0 +1,125 @@ +package proxy + +import ( + "context" + "io" + + wheatCodec "gitee.com/timedb/wheatCache/gateway/codec" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" +) + +// TransparentHandler returns a handler that attempts to proxy all requests that are not registered in the server. +// The indented use here is as a transparent proxy, where the server doesn't know about the services implemented by the +// backends. It should be used as a `grpc.UnknownServiceHandler`. +// +// This can *only* be used if the `server` also uses grpcproxy.CodecForServer() ServerOption. +func TransparentHandler(director StreamDirector) grpc.StreamHandler { + streamer := &handler{director} + return streamer.handler +} + +type handler struct { + director StreamDirector +} + +// handler is where the real magic of proxying happens. +// It is invoked like any gRPC server stream and uses the gRPC server framing to get and receive bytes from the wire, +// forwarding it to a ClientStream established against the relevant ClientConn. +func (s *handler) handler(srv interface{}, serverStream grpc.ServerStream) error { + fullMethodName, ok := grpc.MethodFromServerStream(serverStream) + if !ok { + return grpc.Errorf(codes.Internal, "lowLevelServerStream not exists in context") + } + + outgoingCtx, backendConn, err := s.director(serverStream.Context(), fullMethodName) + if err != nil { + return err + } + + clientCtx, clientCancel := context.WithCancel(outgoingCtx) + defer clientCancel() + + clientStream, err := grpc.NewClientStream(clientCtx, clientStreamDescForProxying, backendConn, fullMethodName) + if err != nil { + return err + } + + s2cErrChan := s.forwardServerToClient(serverStream, clientStream) + c2sErrChan := s.forwardClientToServer(clientStream, serverStream) + + for i := 0; i < 2; i++ { + select { + case s2cErr := <-s2cErrChan: + if s2cErr == io.EOF { + // 客户端流发送完毕正常关闭结束, Proxy 关闭对 Backend 的连接 + clientStream.CloseSend() + break + } + + clientCancel() + return grpc.Errorf(codes.Internal, "failed proxying s2c: %v", s2cErr) + case c2sErr := <-c2sErrChan: + // 服务的没用在提供数据触发这个分支 + serverStream.SetTrailer(clientStream.Trailer()) + + if c2sErr != io.EOF { + return c2sErr + } + + return nil + } + } + + return grpc.Errorf(codes.Internal, "gRPC proxying should never reach this stage.") +} + +func (s *handler) forwardClientToServer(src grpc.ClientStream, dst grpc.ServerStream) chan error { + ret := make(chan error, 1) + go func() { + f := &wheatCodec.Frame{} + for i := 0; ; i++ { + if err := src.RecvMsg(f); err != nil { + ret <- err // this can be io.EOF which is happy case + break + } + if i == 0 { + // This is a bit of a hack, but client to server headers are only readable after first client msg is + // received but must be written to server stream before the first msg is flushed. + // This is the only place to do it nicely. + md, err := src.Header() + if err != nil { + ret <- err + break + } + if err := dst.SendHeader(md); err != nil { + ret <- err + break + } + } + if err := dst.SendMsg(f); err != nil { + ret <- err + break + } + } + }() + return ret +} + +func (s *handler) forwardServerToClient(src grpc.ServerStream, dst grpc.ClientStream) chan error { + ret := make(chan error, 1) + go func() { + f := &wheatCodec.Frame{} + for i := 0; ; i++ { + if err := src.RecvMsg(f); err != nil { + ret <- err // this can be io.EOF which is happy case + break + } + if err := dst.SendMsg(f); err != nil { + ret <- err + break + } + } + }() + return ret +}