vendor: update grpc-go v1.0.2 tag

Fix https://github.com/coreos/etcd/issues/6529.
This commit is contained in:
Gyu-Ho Lee 2016-10-10 11:05:14 -07:00
parent 69ea359e62
commit e3558a64cf
14 changed files with 306 additions and 153 deletions

View File

@ -38,6 +38,7 @@ import (
"sync" "sync"
"golang.org/x/net/context" "golang.org/x/net/context"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/grpclog" "google.golang.org/grpc/grpclog"
"google.golang.org/grpc/naming" "google.golang.org/grpc/naming"
) )
@ -52,6 +53,14 @@ type Address struct {
Metadata interface{} Metadata interface{}
} }
// BalancerConfig specifies the configurations for Balancer.
type BalancerConfig struct {
// DialCreds is the transport credential the Balancer implementation can
// use to dial to a remote load balancer server. The Balancer implementations
// can ignore this if it does not need to talk to another party securely.
DialCreds credentials.TransportCredentials
}
// BalancerGetOptions configures a Get call. // BalancerGetOptions configures a Get call.
// This is the EXPERIMENTAL API and may be changed or extended in the future. // This is the EXPERIMENTAL API and may be changed or extended in the future.
type BalancerGetOptions struct { type BalancerGetOptions struct {
@ -66,11 +75,11 @@ type Balancer interface {
// Start does the initialization work to bootstrap a Balancer. For example, // Start does the initialization work to bootstrap a Balancer. For example,
// this function may start the name resolution and watch the updates. It will // this function may start the name resolution and watch the updates. It will
// be called when dialing. // be called when dialing.
Start(target string) error Start(target string, config BalancerConfig) error
// Up informs the Balancer that gRPC has a connection to the server at // Up informs the Balancer that gRPC has a connection to the server at
// addr. It returns down which is called once the connection to addr gets // addr. It returns down which is called once the connection to addr gets
// lost or closed. // lost or closed.
// TODO: It is not clear how to construct and take advantage the meaningful error // TODO: It is not clear how to construct and take advantage of the meaningful error
// parameter for down. Need realistic demands to guide. // parameter for down. Need realistic demands to guide.
Up(addr Address) (down func(error)) Up(addr Address) (down func(error))
// Get gets the address of a server for the RPC corresponding to ctx. // Get gets the address of a server for the RPC corresponding to ctx.
@ -205,7 +214,12 @@ func (rr *roundRobin) watchAddrUpdates() error {
return nil return nil
} }
func (rr *roundRobin) Start(target string) error { func (rr *roundRobin) Start(target string, config BalancerConfig) error {
rr.mu.Lock()
defer rr.mu.Unlock()
if rr.done {
return ErrClientConnClosing
}
if rr.r == nil { if rr.r == nil {
// If there is no name resolver installed, it is not needed to // If there is no name resolver installed, it is not needed to
// do name resolution. In this case, target is added into rr.addrs // do name resolution. In this case, target is added into rr.addrs

View File

@ -96,7 +96,7 @@ func sendRequest(ctx context.Context, codec Codec, compressor Compressor, callHd
} }
outBuf, err := encode(codec, args, compressor, cbuf) outBuf, err := encode(codec, args, compressor, cbuf)
if err != nil { if err != nil {
return nil, transport.StreamErrorf(codes.Internal, "grpc: %v", err) return nil, Errorf(codes.Internal, "grpc: %v", err)
} }
err = t.Write(stream, outBuf, opts) err = t.Write(stream, outBuf, opts)
// t.NewStream(...) could lead to an early rejection of the RPC (e.g., the service/method // t.NewStream(...) could lead to an early rejection of the RPC (e.g., the service/method
@ -112,7 +112,14 @@ func sendRequest(ctx context.Context, codec Codec, compressor Compressor, callHd
// Invoke sends the RPC request on the wire and returns after response is received. // Invoke sends the RPC request on the wire and returns after response is received.
// Invoke is called by generated code. Also users can call Invoke directly when it // Invoke is called by generated code. Also users can call Invoke directly when it
// is really needed in their use cases. // is really needed in their use cases.
func Invoke(ctx context.Context, method string, args, reply interface{}, cc *ClientConn, opts ...CallOption) (err error) { func Invoke(ctx context.Context, method string, args, reply interface{}, cc *ClientConn, opts ...CallOption) error {
if cc.dopts.unaryInt != nil {
return cc.dopts.unaryInt(ctx, method, args, reply, cc, invoke, opts...)
}
return invoke(ctx, method, args, reply, cc, opts...)
}
func invoke(ctx context.Context, method string, args, reply interface{}, cc *ClientConn, opts ...CallOption) (err error) {
c := defaultCallInfo c := defaultCallInfo
for _, o := range opts { for _, o := range opts {
if err := o.before(&c); err != nil { if err := o.before(&c); err != nil {

View File

@ -83,6 +83,8 @@ var (
// dialOptions configure a Dial call. dialOptions are set by the DialOption // dialOptions configure a Dial call. dialOptions are set by the DialOption
// values passed to Dial. // values passed to Dial.
type dialOptions struct { type dialOptions struct {
unaryInt UnaryClientInterceptor
streamInt StreamClientInterceptor
codec Codec codec Codec
cp Compressor cp Compressor
dc Decompressor dc Decompressor
@ -215,19 +217,48 @@ func WithUserAgent(s string) DialOption {
} }
} }
// WithUnaryInterceptor returns a DialOption that specifies the interceptor for unary RPCs.
func WithUnaryInterceptor(f UnaryClientInterceptor) DialOption {
return func(o *dialOptions) {
o.unaryInt = f
}
}
// WithStreamInterceptor returns a DialOption that specifies the interceptor for streaming RPCs.
func WithStreamInterceptor(f StreamClientInterceptor) DialOption {
return func(o *dialOptions) {
o.streamInt = f
}
}
// Dial creates a client connection to the given target. // Dial creates a client connection to the given target.
func Dial(target string, opts ...DialOption) (*ClientConn, error) { func Dial(target string, opts ...DialOption) (*ClientConn, error) {
return DialContext(context.Background(), target, opts...) return DialContext(context.Background(), target, opts...)
} }
// DialContext creates a client connection to the given target // DialContext creates a client connection to the given target. ctx can be used to
// using the supplied context. // cancel or expire the pending connecting. Once this function returns, the
func DialContext(ctx context.Context, target string, opts ...DialOption) (*ClientConn, error) { // cancellation and expiration of ctx will be noop. Users should call ClientConn.Close
// to terminate all the pending operations after this function returns.
// This is the EXPERIMENTAL API.
func DialContext(ctx context.Context, target string, opts ...DialOption) (conn *ClientConn, err error) {
cc := &ClientConn{ cc := &ClientConn{
target: target, target: target,
conns: make(map[Address]*addrConn), conns: make(map[Address]*addrConn),
} }
cc.ctx, cc.cancel = context.WithCancel(ctx) cc.ctx, cc.cancel = context.WithCancel(context.Background())
defer func() {
select {
case <-ctx.Done():
conn, err = nil, ctx.Err()
default:
}
if err != nil {
cc.Close()
}
}()
for _, opt := range opts { for _, opt := range opts {
opt(&cc.dopts) opt(&cc.dopts)
} }
@ -239,17 +270,34 @@ func DialContext(ctx context.Context, target string, opts ...DialOption) (*Clien
if cc.dopts.bs == nil { if cc.dopts.bs == nil {
cc.dopts.bs = DefaultBackoffConfig cc.dopts.bs = DefaultBackoffConfig
} }
creds := cc.dopts.copts.TransportCredentials
var ( if creds != nil && creds.Info().ServerName != "" {
ok bool cc.authority = creds.Info().ServerName
addrs []Address } else {
) colonPos := strings.LastIndex(target, ":")
if colonPos == -1 {
colonPos = len(target)
}
cc.authority = target[:colonPos]
}
var ok bool
waitC := make(chan error, 1)
go func() {
var addrs []Address
if cc.dopts.balancer == nil { if cc.dopts.balancer == nil {
// Connect to target directly if balancer is nil. // Connect to target directly if balancer is nil.
addrs = append(addrs, Address{Addr: target}) addrs = append(addrs, Address{Addr: target})
} else { } else {
if err := cc.dopts.balancer.Start(target); err != nil { var credsClone credentials.TransportCredentials
return nil, err if creds != nil {
credsClone = creds.Clone()
}
config := BalancerConfig{
DialCreds: credsClone,
}
if err := cc.dopts.balancer.Start(target, config); err != nil {
waitC <- err
return
} }
ch := cc.dopts.balancer.Notify() ch := cc.dopts.balancer.Notify()
if ch == nil { if ch == nil {
@ -258,12 +306,11 @@ func DialContext(ctx context.Context, target string, opts ...DialOption) (*Clien
} else { } else {
addrs, ok = <-ch addrs, ok = <-ch
if !ok || len(addrs) == 0 { if !ok || len(addrs) == 0 {
return nil, errNoAddr waitC <- errNoAddr
return
} }
} }
} }
waitC := make(chan error, 1)
go func() {
for _, a := range addrs { for _, a := range addrs {
if err := cc.resetAddrConn(a, false, nil); err != nil { if err := cc.resetAddrConn(a, false, nil); err != nil {
waitC <- err waitC <- err
@ -277,16 +324,13 @@ func DialContext(ctx context.Context, target string, opts ...DialOption) (*Clien
timeoutCh = time.After(cc.dopts.timeout) timeoutCh = time.After(cc.dopts.timeout)
} }
select { select {
case <-ctx.Done():
return nil, ctx.Err()
case err := <-waitC: case err := <-waitC:
if err != nil { if err != nil {
cc.Close()
return nil, err return nil, err
} }
case <-cc.ctx.Done():
cc.Close()
return nil, cc.ctx.Err()
case <-timeoutCh: case <-timeoutCh:
cc.Close()
return nil, ErrClientConnTimeout return nil, ErrClientConnTimeout
} }
// If balancer is nil or balancer.Notify() is nil, ok will be false here. // If balancer is nil or balancer.Notify() is nil, ok will be false here.
@ -294,11 +338,6 @@ func DialContext(ctx context.Context, target string, opts ...DialOption) (*Clien
if ok { if ok {
go cc.lbWatcher() go cc.lbWatcher()
} }
colonPos := strings.LastIndex(target, ":")
if colonPos == -1 {
colonPos = len(target)
}
cc.authority = target[:colonPos]
return cc, nil return cc, nil
} }
@ -652,7 +691,7 @@ func (ac *addrConn) resetTransport(closeTransport bool) error {
if e, ok := err.(transport.ConnectionError); ok && !e.Temporary() { if e, ok := err.(transport.ConnectionError); ok && !e.Temporary() {
return err return err
} }
grpclog.Printf("grpc: addrConn.resetTransport failed to create client transport: %v; Reconnecting to %q", err, ac.addr) grpclog.Printf("grpc: addrConn.resetTransport failed to create client transport: %v; Reconnecting to %v", err, ac.addr)
ac.mu.Lock() ac.mu.Lock()
if ac.state == Shutdown { if ac.state == Shutdown {
// ac.tearDown(...) has been invoked. // ac.tearDown(...) has been invoked.

View File

@ -40,6 +40,7 @@ package credentials // import "google.golang.org/grpc/credentials"
import ( import (
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
"errors"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"net" "net"
@ -71,7 +72,7 @@ type PerRPCCredentials interface {
} }
// ProtocolInfo provides information regarding the gRPC wire protocol version, // ProtocolInfo provides information regarding the gRPC wire protocol version,
// security protocol, security protocol version in use, etc. // security protocol, security protocol version in use, server name, etc.
type ProtocolInfo struct { type ProtocolInfo struct {
// ProtocolVersion is the gRPC wire protocol version. // ProtocolVersion is the gRPC wire protocol version.
ProtocolVersion string ProtocolVersion string
@ -79,6 +80,8 @@ type ProtocolInfo struct {
SecurityProtocol string SecurityProtocol string
// SecurityVersion is the security protocol version. // SecurityVersion is the security protocol version.
SecurityVersion string SecurityVersion string
// ServerName is the user-configured server name.
ServerName string
} }
// AuthInfo defines the common interface for the auth information the users are interested in. // AuthInfo defines the common interface for the auth information the users are interested in.
@ -86,6 +89,12 @@ type AuthInfo interface {
AuthType() string AuthType() string
} }
var (
// ErrConnDispatched indicates that rawConn has been dispatched out of gRPC
// and the caller should not close rawConn.
ErrConnDispatched = errors.New("credentials: rawConn is dispatched out of gRPC")
)
// TransportCredentials defines the common interface for all the live gRPC wire // TransportCredentials defines the common interface for all the live gRPC wire
// protocols and supported transport security protocols (e.g., TLS, SSL). // protocols and supported transport security protocols (e.g., TLS, SSL).
type TransportCredentials interface { type TransportCredentials interface {
@ -100,6 +109,12 @@ type TransportCredentials interface {
ServerHandshake(net.Conn) (net.Conn, AuthInfo, error) ServerHandshake(net.Conn) (net.Conn, AuthInfo, error)
// Info provides the ProtocolInfo of this TransportCredentials. // Info provides the ProtocolInfo of this TransportCredentials.
Info() ProtocolInfo Info() ProtocolInfo
// Clone makes a copy of this TransportCredentials.
Clone() TransportCredentials
// OverrideServerName overrides the server name used to verify the hostname on the returned certificates from the server.
// gRPC internals also use it to override the virtual hosting name if it is set.
// It must be called before dialing. Currently, this is only used by grpclb.
OverrideServerName(string) error
} }
// TLSInfo contains the auth information for a TLS authenticated connection. // TLSInfo contains the auth information for a TLS authenticated connection.
@ -123,19 +138,10 @@ func (c tlsCreds) Info() ProtocolInfo {
return ProtocolInfo{ return ProtocolInfo{
SecurityProtocol: "tls", SecurityProtocol: "tls",
SecurityVersion: "1.2", SecurityVersion: "1.2",
ServerName: c.config.ServerName,
} }
} }
// GetRequestMetadata returns nil, nil since TLS credentials does not have
// metadata.
func (c *tlsCreds) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) {
return nil, nil
}
func (c *tlsCreds) RequireTransportSecurity() bool {
return true
}
func (c *tlsCreds) ClientHandshake(ctx context.Context, addr string, rawConn net.Conn) (_ net.Conn, _ AuthInfo, err error) { func (c *tlsCreds) ClientHandshake(ctx context.Context, addr string, rawConn net.Conn) (_ net.Conn, _ AuthInfo, err error) {
// use local cfg to avoid clobbering ServerName if using multiple endpoints // use local cfg to avoid clobbering ServerName if using multiple endpoints
cfg := cloneTLSConfig(c.config) cfg := cloneTLSConfig(c.config)
@ -172,6 +178,15 @@ func (c *tlsCreds) ServerHandshake(rawConn net.Conn) (net.Conn, AuthInfo, error)
return conn, TLSInfo{conn.ConnectionState()}, nil return conn, TLSInfo{conn.ConnectionState()}, nil
} }
func (c *tlsCreds) Clone() TransportCredentials {
return NewTLS(c.config)
}
func (c *tlsCreds) OverrideServerName(serverNameOverride string) error {
c.config.ServerName = serverNameOverride
return nil
}
// NewTLS uses c to construct a TransportCredentials based on TLS. // NewTLS uses c to construct a TransportCredentials based on TLS.
func NewTLS(c *tls.Config) TransportCredentials { func NewTLS(c *tls.Config) TransportCredentials {
tc := &tlsCreds{cloneTLSConfig(c)} tc := &tlsCreds{cloneTLSConfig(c)}
@ -180,12 +195,16 @@ func NewTLS(c *tls.Config) TransportCredentials {
} }
// NewClientTLSFromCert constructs a TLS from the input certificate for client. // NewClientTLSFromCert constructs a TLS from the input certificate for client.
func NewClientTLSFromCert(cp *x509.CertPool, serverName string) TransportCredentials { // serverNameOverride is for testing only. If set to a non empty string,
return NewTLS(&tls.Config{ServerName: serverName, RootCAs: cp}) // it will override the virtual host name of authority (e.g. :authority header field) in requests.
func NewClientTLSFromCert(cp *x509.CertPool, serverNameOverride string) TransportCredentials {
return NewTLS(&tls.Config{ServerName: serverNameOverride, RootCAs: cp})
} }
// NewClientTLSFromFile constructs a TLS from the input certificate file for client. // NewClientTLSFromFile constructs a TLS from the input certificate file for client.
func NewClientTLSFromFile(certFile, serverName string) (TransportCredentials, error) { // serverNameOverride is for testing only. If set to a non empty string,
// it will override the virtual host name of authority (e.g. :authority header field) in requests.
func NewClientTLSFromFile(certFile, serverNameOverride string) (TransportCredentials, error) {
b, err := ioutil.ReadFile(certFile) b, err := ioutil.ReadFile(certFile)
if err != nil { if err != nil {
return nil, err return nil, err
@ -194,7 +213,7 @@ func NewClientTLSFromFile(certFile, serverName string) (TransportCredentials, er
if !cp.AppendCertsFromPEM(b) { if !cp.AppendCertsFromPEM(b) {
return nil, fmt.Errorf("credentials: failed to append certificates") return nil, fmt.Errorf("credentials: failed to append certificates")
} }
return NewTLS(&tls.Config{ServerName: serverName, RootCAs: cp}), nil return NewTLS(&tls.Config{ServerName: serverNameOverride, RootCAs: cp}), nil
} }
// NewServerTLSFromCert constructs a TLS from the input certificate for server. // NewServerTLSFromCert constructs a TLS from the input certificate for server.

View File

@ -37,6 +37,22 @@ import (
"golang.org/x/net/context" "golang.org/x/net/context"
) )
// UnaryInvoker is called by UnaryClientInterceptor to complete RPCs.
type UnaryInvoker func(ctx context.Context, method string, req, reply interface{}, cc *ClientConn, opts ...CallOption) error
// UnaryClientInterceptor intercepts the execution of a unary RPC on the client. inovker is the handler to complete the RPC
// and it is the responsibility of the interceptor to call it.
// This is the EXPERIMENTAL API.
type UnaryClientInterceptor func(ctx context.Context, method string, req, reply interface{}, cc *ClientConn, invoker UnaryInvoker, opts ...CallOption) error
// Streamer is called by StreamClientInterceptor to create a ClientStream.
type Streamer func(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, opts ...CallOption) (ClientStream, error)
// StreamClientInterceptor intercepts the creation of ClientStream. It may return a custom ClientStream to intercept all I/O
// operations. streamer is the handlder to create a ClientStream and it is the responsibility of the interceptor to call it.
// This is the EXPERIMENTAL API.
type StreamClientInterceptor func(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, streamer Streamer, opts ...CallOption) (ClientStream, error)
// UnaryServerInfo consists of various information about a unary RPC on // UnaryServerInfo consists of various information about a unary RPC on
// server side. All per-rpc information may be mutated by the interceptor. // server side. All per-rpc information may be mutated by the interceptor.
type UnaryServerInfo struct { type UnaryServerInfo struct {

View File

@ -117,10 +117,17 @@ func (md MD) Len() int {
// Copy returns a copy of md. // Copy returns a copy of md.
func (md MD) Copy() MD { func (md MD) Copy() MD {
return Join(md)
}
// Join joins any number of MDs into a single MD.
// The order of values for each key is determined by the order in which
// the MDs containing those values are presented to Join.
func Join(mds ...MD) MD {
out := MD{} out := MD{}
for _, md := range mds {
for k, v := range md { for k, v := range md {
for _, i := range v { out[k] = append(out[k], v...)
out[k] = append(out[k], i)
} }
} }
return out return out

View File

@ -303,10 +303,10 @@ func checkRecvPayload(pf payloadFormat, recvCompress string, dc Decompressor) er
case compressionNone: case compressionNone:
case compressionMade: case compressionMade:
if dc == nil || recvCompress != dc.Type() { if dc == nil || recvCompress != dc.Type() {
return transport.StreamErrorf(codes.Unimplemented, "grpc: Decompressor is not installed for grpc-encoding %q", recvCompress) return Errorf(codes.Unimplemented, "grpc: Decompressor is not installed for grpc-encoding %q", recvCompress)
} }
default: default:
return transport.StreamErrorf(codes.Internal, "grpc: received unexpected payload format %d", pf) return Errorf(codes.Internal, "grpc: received unexpected payload format %d", pf)
} }
return nil return nil
} }

View File

@ -324,7 +324,7 @@ func (s *Server) useTransportAuthenticator(rawConn net.Conn) (net.Conn, credenti
// Serve accepts incoming connections on the listener lis, creating a new // Serve accepts incoming connections on the listener lis, creating a new
// ServerTransport and service goroutine for each. The service goroutines // ServerTransport and service goroutine for each. The service goroutines
// read gRPC requests and then call the registered handlers to reply to them. // read gRPC requests and then call the registered handlers to reply to them.
// Service returns when lis.Accept fails. lis will be closed when // Serve returns when lis.Accept fails. lis will be closed when
// this method returns. // this method returns.
func (s *Server) Serve(lis net.Listener) error { func (s *Server) Serve(lis net.Listener) error {
s.mu.Lock() s.mu.Lock()
@ -367,7 +367,10 @@ func (s *Server) handleRawConn(rawConn net.Conn) {
s.errorf("ServerHandshake(%q) failed: %v", rawConn.RemoteAddr(), err) s.errorf("ServerHandshake(%q) failed: %v", rawConn.RemoteAddr(), err)
s.mu.Unlock() s.mu.Unlock()
grpclog.Printf("grpc: Server.Serve failed to complete security handshake from %q: %v", rawConn.RemoteAddr(), err) grpclog.Printf("grpc: Server.Serve failed to complete security handshake from %q: %v", rawConn.RemoteAddr(), err)
// If serverHandShake returns ErrConnDispatched, keep rawConn open.
if err != credentials.ErrConnDispatched {
rawConn.Close() rawConn.Close()
}
return return
} }
@ -544,7 +547,7 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
return err return err
} }
if err == io.ErrUnexpectedEOF { if err == io.ErrUnexpectedEOF {
err = transport.StreamError{Code: codes.Internal, Desc: "io.ErrUnexpectedEOF"} err = Errorf(codes.Internal, io.ErrUnexpectedEOF.Error())
} }
if err != nil { if err != nil {
switch err := err.(type) { switch err := err.(type) {
@ -566,8 +569,8 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
if err := checkRecvPayload(pf, stream.RecvCompress(), s.opts.dc); err != nil { if err := checkRecvPayload(pf, stream.RecvCompress(), s.opts.dc); err != nil {
switch err := err.(type) { switch err := err.(type) {
case transport.StreamError: case *rpcError:
if err := t.WriteStatus(stream, err.Code, err.Desc); err != nil { if err := t.WriteStatus(stream, err.code, err.desc); err != nil {
grpclog.Printf("grpc: Server.processUnaryRPC failed to write status %v", err) grpclog.Printf("grpc: Server.processUnaryRPC failed to write status %v", err)
} }
default: default:
@ -870,25 +873,28 @@ func SendHeader(ctx context.Context, md metadata.MD) error {
} }
stream, ok := transport.StreamFromContext(ctx) stream, ok := transport.StreamFromContext(ctx)
if !ok { if !ok {
return fmt.Errorf("grpc: failed to fetch the stream from the context %v", ctx) return Errorf(codes.Internal, "grpc: failed to fetch the stream from the context %v", ctx)
} }
t := stream.ServerTransport() t := stream.ServerTransport()
if t == nil { if t == nil {
grpclog.Fatalf("grpc: SendHeader: %v has no ServerTransport to send header metadata.", stream) grpclog.Fatalf("grpc: SendHeader: %v has no ServerTransport to send header metadata.", stream)
} }
return t.WriteHeader(stream, md) if err := t.WriteHeader(stream, md); err != nil {
return toRPCErr(err)
}
return nil
} }
// SetTrailer sets the trailer metadata that will be sent when an RPC returns. // SetTrailer sets the trailer metadata that will be sent when an RPC returns.
// It may be called at most once from a unary RPC handler. The ctx is the RPC // When called more than once, all the provided metadata will be merged.
// handler's Context or one derived from it. // The ctx is the RPC handler's Context or one derived from it.
func SetTrailer(ctx context.Context, md metadata.MD) error { func SetTrailer(ctx context.Context, md metadata.MD) error {
if md.Len() == 0 { if md.Len() == 0 {
return nil return nil
} }
stream, ok := transport.StreamFromContext(ctx) stream, ok := transport.StreamFromContext(ctx)
if !ok { if !ok {
return fmt.Errorf("grpc: failed to fetch the stream from the context %v", ctx) return Errorf(codes.Internal, "grpc: failed to fetch the stream from the context %v", ctx)
} }
return stream.SetTrailer(md) return stream.SetTrailer(md)
} }

View File

@ -97,7 +97,14 @@ type ClientStream interface {
// NewClientStream creates a new Stream for the client side. This is called // NewClientStream creates a new Stream for the client side. This is called
// by generated code. // by generated code.
func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, opts ...CallOption) (_ ClientStream, err error) { func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, opts ...CallOption) (ClientStream, error) {
if cc.dopts.streamInt != nil {
return cc.dopts.streamInt(ctx, desc, cc, method, newClientStream, opts...)
}
return newClientStream(ctx, desc, cc, method, opts...)
}
func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, opts ...CallOption) (_ ClientStream, err error) {
var ( var (
t transport.ClientTransport t transport.ClientTransport
s *transport.Stream s *transport.Stream
@ -296,7 +303,7 @@ func (cs *clientStream) SendMsg(m interface{}) (err error) {
} }
}() }()
if err != nil { if err != nil {
return transport.StreamErrorf(codes.Internal, "grpc: %v", err) return Errorf(codes.Internal, "grpc: %v", err)
} }
return cs.t.Write(cs.s, out, &transport.Options{Last: false}) return cs.t.Write(cs.s, out, &transport.Options{Last: false})
} }
@ -407,8 +414,8 @@ type ServerStream interface {
// after SendProto. It fails if called multiple times or if // after SendProto. It fails if called multiple times or if
// called after SendProto. // called after SendProto.
SendHeader(metadata.MD) error SendHeader(metadata.MD) error
// SetTrailer sets the trailer metadata which will be sent with the // SetTrailer sets the trailer metadata which will be sent with the RPC status.
// RPC status. // When called more than once, all the provided metadata will be merged.
SetTrailer(metadata.MD) SetTrailer(metadata.MD)
Stream Stream
} }
@ -468,10 +475,13 @@ func (ss *serverStream) SendMsg(m interface{}) (err error) {
} }
}() }()
if err != nil { if err != nil {
err = transport.StreamErrorf(codes.Internal, "grpc: %v", err) err = Errorf(codes.Internal, "grpc: %v", err)
return err return err
} }
return ss.t.Write(ss.s, out, &transport.Options{Last: false}) if err := ss.t.Write(ss.s, out, &transport.Options{Last: false}); err != nil {
return toRPCErr(err)
}
return nil
} }
func (ss *serverStream) RecvMsg(m interface{}) (err error) { func (ss *serverStream) RecvMsg(m interface{}) (err error) {
@ -489,5 +499,14 @@ func (ss *serverStream) RecvMsg(m interface{}) (err error) {
ss.mu.Unlock() ss.mu.Unlock()
} }
}() }()
return recv(ss.p, ss.codec, ss.s, ss.dc, m, ss.maxMsgSize) if err := recv(ss.p, ss.codec, ss.s, ss.dc, m, ss.maxMsgSize); err != nil {
if err == io.EOF {
return err
}
if err == io.ErrUnexpectedEOF {
err = Errorf(codes.Internal, io.ErrUnexpectedEOF.Error())
}
return toRPCErr(err)
}
return nil
} }

View File

@ -85,7 +85,7 @@ func NewServerHandlerTransport(w http.ResponseWriter, r *http.Request) (ServerTr
if v := r.Header.Get("grpc-timeout"); v != "" { if v := r.Header.Get("grpc-timeout"); v != "" {
to, err := decodeTimeout(v) to, err := decodeTimeout(v)
if err != nil { if err != nil {
return nil, StreamErrorf(codes.Internal, "malformed time-out: %v", err) return nil, streamErrorf(codes.Internal, "malformed time-out: %v", err)
} }
st.timeoutSet = true st.timeoutSet = true
st.timeout = to st.timeout = to
@ -393,5 +393,5 @@ func mapRecvMsgError(err error) error {
} }
} }
} }
return ConnectionError{Desc: err.Error()} return connectionErrorf(true, err, err.Error())
} }

View File

@ -114,14 +114,42 @@ func dial(fn func(context.Context, string) (net.Conn, error), ctx context.Contex
return dialContext(ctx, "tcp", addr) return dialContext(ctx, "tcp", addr)
} }
func isTemporary(err error) bool {
switch err {
case io.EOF:
// Connection closures may be resolved upon retry, and are thus
// treated as temporary.
return true
case context.DeadlineExceeded:
// In Go 1.7, context.DeadlineExceeded implements Timeout(), and this
// special case is not needed. Until then, we need to keep this
// clause.
return true
}
switch err := err.(type) {
case interface {
Temporary() bool
}:
return err.Temporary()
case interface {
Timeout() bool
}:
// Timeouts may be resolved upon retry, and are thus treated as
// temporary.
return err.Timeout()
}
return false
}
// newHTTP2Client constructs a connected ClientTransport to addr based on HTTP2 // newHTTP2Client constructs a connected ClientTransport to addr based on HTTP2
// and starts to receive messages on it. Non-nil error returns if construction // and starts to receive messages on it. Non-nil error returns if construction
// fails. // fails.
func newHTTP2Client(ctx context.Context, addr string, opts ConnectOptions) (_ ClientTransport, err error) { func newHTTP2Client(ctx context.Context, addr string, opts ConnectOptions) (_ ClientTransport, err error) {
scheme := "http" scheme := "http"
conn, connErr := dial(opts.Dialer, ctx, addr) conn, err := dial(opts.Dialer, ctx, addr)
if connErr != nil { if err != nil {
return nil, ConnectionErrorf(true, connErr, "transport: %v", connErr) return nil, connectionErrorf(true, err, "transport: %v", err)
} }
// Any further errors will close the underlying connection // Any further errors will close the underlying connection
defer func(conn net.Conn) { defer func(conn net.Conn) {
@ -132,12 +160,13 @@ func newHTTP2Client(ctx context.Context, addr string, opts ConnectOptions) (_ Cl
var authInfo credentials.AuthInfo var authInfo credentials.AuthInfo
if creds := opts.TransportCredentials; creds != nil { if creds := opts.TransportCredentials; creds != nil {
scheme = "https" scheme = "https"
conn, authInfo, connErr = creds.ClientHandshake(ctx, addr, conn) conn, authInfo, err = creds.ClientHandshake(ctx, addr, conn)
if err != nil {
// Credentials handshake errors are typically considered permanent
// to avoid retrying on e.g. bad certificates.
temp := isTemporary(err)
return nil, connectionErrorf(temp, err, "transport: %v", err)
} }
if connErr != nil {
// Credentials handshake error is not a temporary error (unless the error
// was the connection closing).
return nil, ConnectionErrorf(connErr == io.EOF, connErr, "transport: %v", connErr)
} }
ua := primaryUA ua := primaryUA
if opts.UserAgent != "" { if opts.UserAgent != "" {
@ -176,11 +205,11 @@ func newHTTP2Client(ctx context.Context, addr string, opts ConnectOptions) (_ Cl
n, err := t.conn.Write(clientPreface) n, err := t.conn.Write(clientPreface)
if err != nil { if err != nil {
t.Close() t.Close()
return nil, ConnectionErrorf(true, err, "transport: %v", err) return nil, connectionErrorf(true, err, "transport: %v", err)
} }
if n != len(clientPreface) { if n != len(clientPreface) {
t.Close() t.Close()
return nil, ConnectionErrorf(true, err, "transport: preface mismatch, wrote %d bytes; want %d", n, len(clientPreface)) return nil, connectionErrorf(true, err, "transport: preface mismatch, wrote %d bytes; want %d", n, len(clientPreface))
} }
if initialWindowSize != defaultWindowSize { if initialWindowSize != defaultWindowSize {
err = t.framer.writeSettings(true, http2.Setting{ err = t.framer.writeSettings(true, http2.Setting{
@ -192,13 +221,13 @@ func newHTTP2Client(ctx context.Context, addr string, opts ConnectOptions) (_ Cl
} }
if err != nil { if err != nil {
t.Close() t.Close()
return nil, ConnectionErrorf(true, err, "transport: %v", err) return nil, connectionErrorf(true, err, "transport: %v", err)
} }
// Adjust the connection flow control window if needed. // Adjust the connection flow control window if needed.
if delta := uint32(initialConnWindowSize - defaultWindowSize); delta > 0 { if delta := uint32(initialConnWindowSize - defaultWindowSize); delta > 0 {
if err := t.framer.writeWindowUpdate(true, 0, delta); err != nil { if err := t.framer.writeWindowUpdate(true, 0, delta); err != nil {
t.Close() t.Close()
return nil, ConnectionErrorf(true, err, "transport: %v", err) return nil, connectionErrorf(true, err, "transport: %v", err)
} }
} }
go t.controller() go t.controller()
@ -223,8 +252,10 @@ func (t *http2Client) newStream(ctx context.Context, callHdr *CallHdr) *Stream {
s.windowHandler = func(n int) { s.windowHandler = func(n int) {
t.updateWindow(s, uint32(n)) t.updateWindow(s, uint32(n))
} }
// Make a stream be able to cancel the pending operations by itself. // The client side stream context should have exactly the same life cycle with the user provided context.
s.ctx, s.cancel = context.WithCancel(ctx) // That means, s.ctx should be read-only. And s.ctx is done iff ctx is done.
// So we use the original context here instead of creating a copy.
s.ctx = ctx
s.dec = &recvBufferReader{ s.dec = &recvBufferReader{
ctx: s.ctx, ctx: s.ctx,
goAway: s.goAway, goAway: s.goAway,
@ -236,16 +267,6 @@ func (t *http2Client) newStream(ctx context.Context, callHdr *CallHdr) *Stream {
// NewStream creates a stream and register it into the transport as "active" // NewStream creates a stream and register it into the transport as "active"
// streams. // streams.
func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Stream, err error) { func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Stream, err error) {
// Record the timeout value on the context.
var timeout time.Duration
if dl, ok := ctx.Deadline(); ok {
timeout = dl.Sub(time.Now())
}
select {
case <-ctx.Done():
return nil, ContextErr(ctx.Err())
default:
}
pr := &peer.Peer{ pr := &peer.Peer{
Addr: t.conn.RemoteAddr(), Addr: t.conn.RemoteAddr(),
} }
@ -266,12 +287,12 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
} }
pos := strings.LastIndex(callHdr.Method, "/") pos := strings.LastIndex(callHdr.Method, "/")
if pos == -1 { if pos == -1 {
return nil, StreamErrorf(codes.InvalidArgument, "transport: malformed method name: %q", callHdr.Method) return nil, streamErrorf(codes.InvalidArgument, "transport: malformed method name: %q", callHdr.Method)
} }
audience := "https://" + callHdr.Host + port + callHdr.Method[:pos] audience := "https://" + callHdr.Host + port + callHdr.Method[:pos]
data, err := c.GetRequestMetadata(ctx, audience) data, err := c.GetRequestMetadata(ctx, audience)
if err != nil { if err != nil {
return nil, StreamErrorf(codes.InvalidArgument, "transport: %v", err) return nil, streamErrorf(codes.InvalidArgument, "transport: %v", err)
} }
for k, v := range data { for k, v := range data {
authData[k] = v authData[k] = v
@ -352,9 +373,12 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
if callHdr.SendCompress != "" { if callHdr.SendCompress != "" {
t.hEnc.WriteField(hpack.HeaderField{Name: "grpc-encoding", Value: callHdr.SendCompress}) t.hEnc.WriteField(hpack.HeaderField{Name: "grpc-encoding", Value: callHdr.SendCompress})
} }
if timeout > 0 { if dl, ok := ctx.Deadline(); ok {
// Send out timeout regardless its value. The server can detect timeout context by itself.
timeout := dl.Sub(time.Now())
t.hEnc.WriteField(hpack.HeaderField{Name: "grpc-timeout", Value: encodeTimeout(timeout)}) t.hEnc.WriteField(hpack.HeaderField{Name: "grpc-timeout", Value: encodeTimeout(timeout)})
} }
for k, v := range authData { for k, v := range authData {
// Capital header names are illegal in HTTP/2. // Capital header names are illegal in HTTP/2.
k = strings.ToLower(k) k = strings.ToLower(k)
@ -408,7 +432,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
} }
if err != nil { if err != nil {
t.notifyError(err) t.notifyError(err)
return nil, ConnectionErrorf(true, err, "transport: %v", err) return nil, connectionErrorf(true, err, "transport: %v", err)
} }
} }
t.writableChan <- 0 t.writableChan <- 0
@ -454,7 +478,7 @@ func (t *http2Client) CloseStream(s *Stream, err error) {
} }
s.state = streamDone s.state = streamDone
s.mu.Unlock() s.mu.Unlock()
if _, ok := err.(StreamError); ok { if se, ok := err.(StreamError); ok && se.Code != codes.DeadlineExceeded {
t.controlBuf.put(&resetStream{s.id, http2.ErrCodeCancel}) t.controlBuf.put(&resetStream{s.id, http2.ErrCodeCancel})
} }
} }
@ -622,7 +646,7 @@ func (t *http2Client) Write(s *Stream, data []byte, opts *Options) error {
// invoked. // invoked.
if err := t.framer.writeData(forceFlush, s.id, endStream, p); err != nil { if err := t.framer.writeData(forceFlush, s.id, endStream, p); err != nil {
t.notifyError(err) t.notifyError(err)
return ConnectionErrorf(true, err, "transport: %v", err) return connectionErrorf(true, err, "transport: %v", err)
} }
if t.framer.adjustNumWriters(-1) == 0 { if t.framer.adjustNumWriters(-1) == 0 {
t.framer.flushWrite() t.framer.flushWrite()
@ -670,7 +694,7 @@ func (t *http2Client) updateWindow(s *Stream, n uint32) {
func (t *http2Client) handleData(f *http2.DataFrame) { func (t *http2Client) handleData(f *http2.DataFrame) {
size := len(f.Data()) size := len(f.Data())
if err := t.fc.onData(uint32(size)); err != nil { if err := t.fc.onData(uint32(size)); err != nil {
t.notifyError(ConnectionErrorf(true, err, "%v", err)) t.notifyError(connectionErrorf(true, err, "%v", err))
return return
} }
// Select the right stream to dispatch. // Select the right stream to dispatch.
@ -776,7 +800,7 @@ func (t *http2Client) handleGoAway(f *http2.GoAwayFrame) {
if t.state == reachable || t.state == draining { if t.state == reachable || t.state == draining {
if f.LastStreamID > 0 && f.LastStreamID%2 != 1 { if f.LastStreamID > 0 && f.LastStreamID%2 != 1 {
t.mu.Unlock() t.mu.Unlock()
t.notifyError(ConnectionErrorf(true, nil, "received illegal http2 GOAWAY frame: stream ID %d is even", f.LastStreamID)) t.notifyError(connectionErrorf(true, nil, "received illegal http2 GOAWAY frame: stream ID %d is even", f.LastStreamID))
return return
} }
select { select {
@ -785,7 +809,7 @@ func (t *http2Client) handleGoAway(f *http2.GoAwayFrame) {
// t.goAway has been closed (i.e.,multiple GoAways). // t.goAway has been closed (i.e.,multiple GoAways).
if id < f.LastStreamID { if id < f.LastStreamID {
t.mu.Unlock() t.mu.Unlock()
t.notifyError(ConnectionErrorf(true, nil, "received illegal http2 GOAWAY frame: previously recv GOAWAY frame with LastStramID %d, currently recv %d", id, f.LastStreamID)) t.notifyError(connectionErrorf(true, nil, "received illegal http2 GOAWAY frame: previously recv GOAWAY frame with LastStramID %d, currently recv %d", id, f.LastStreamID))
return return
} }
t.prevGoAwayID = id t.prevGoAwayID = id
@ -823,6 +847,12 @@ func (t *http2Client) operateHeaders(frame *http2.MetaHeadersFrame) {
state.processHeaderField(hf) state.processHeaderField(hf)
} }
if state.err != nil { if state.err != nil {
s.mu.Lock()
if !s.headerDone {
close(s.headerChan)
s.headerDone = true
}
s.mu.Unlock()
s.write(recvMsg{err: state.err}) s.write(recvMsg{err: state.err})
// Something wrong. Stops reading even when there is remaining. // Something wrong. Stops reading even when there is remaining.
return return
@ -900,7 +930,7 @@ func (t *http2Client) reader() {
t.mu.Unlock() t.mu.Unlock()
if s != nil { if s != nil {
// use error detail to provide better err message // use error detail to provide better err message
handleMalformedHTTP2(s, StreamErrorf(http2ErrConvTab[se.Code], "%v", t.framer.errorDetail())) handleMalformedHTTP2(s, streamErrorf(http2ErrConvTab[se.Code], "%v", t.framer.errorDetail()))
} }
continue continue
} else { } else {

View File

@ -111,12 +111,12 @@ func newHTTP2Server(conn net.Conn, maxStreams uint32, authInfo credentials.AuthI
Val: uint32(initialWindowSize)}) Val: uint32(initialWindowSize)})
} }
if err := framer.writeSettings(true, settings...); err != nil { if err := framer.writeSettings(true, settings...); err != nil {
return nil, ConnectionErrorf(true, err, "transport: %v", err) return nil, connectionErrorf(true, err, "transport: %v", err)
} }
// Adjust the connection flow control window if needed. // Adjust the connection flow control window if needed.
if delta := uint32(initialConnWindowSize - defaultWindowSize); delta > 0 { if delta := uint32(initialConnWindowSize - defaultWindowSize); delta > 0 {
if err := framer.writeWindowUpdate(true, 0, delta); err != nil { if err := framer.writeWindowUpdate(true, 0, delta); err != nil {
return nil, ConnectionErrorf(true, err, "transport: %v", err) return nil, connectionErrorf(true, err, "transport: %v", err)
} }
} }
var buf bytes.Buffer var buf bytes.Buffer
@ -448,7 +448,7 @@ func (t *http2Server) writeHeaders(s *Stream, b *bytes.Buffer, endStream bool) e
} }
if err != nil { if err != nil {
t.Close() t.Close()
return ConnectionErrorf(true, err, "transport: %v", err) return connectionErrorf(true, err, "transport: %v", err)
} }
} }
return nil return nil
@ -544,7 +544,7 @@ func (t *http2Server) Write(s *Stream, data []byte, opts *Options) error {
s.mu.Lock() s.mu.Lock()
if s.state == streamDone { if s.state == streamDone {
s.mu.Unlock() s.mu.Unlock()
return StreamErrorf(codes.Unknown, "the stream has been done") return streamErrorf(codes.Unknown, "the stream has been done")
} }
if !s.headerOk { if !s.headerOk {
writeHeaderFrame = true writeHeaderFrame = true
@ -568,7 +568,7 @@ func (t *http2Server) Write(s *Stream, data []byte, opts *Options) error {
} }
if err := t.framer.writeHeaders(false, p); err != nil { if err := t.framer.writeHeaders(false, p); err != nil {
t.Close() t.Close()
return ConnectionErrorf(true, err, "transport: %v", err) return connectionErrorf(true, err, "transport: %v", err)
} }
t.writableChan <- 0 t.writableChan <- 0
} }
@ -642,7 +642,7 @@ func (t *http2Server) Write(s *Stream, data []byte, opts *Options) error {
} }
if err := t.framer.writeData(forceFlush, s.id, false, p); err != nil { if err := t.framer.writeData(forceFlush, s.id, false, p); err != nil {
t.Close() t.Close()
return ConnectionErrorf(true, err, "transport: %v", err) return connectionErrorf(true, err, "transport: %v", err)
} }
if t.framer.adjustNumWriters(-1) == 0 { if t.framer.adjustNumWriters(-1) == 0 {
t.framer.flushWrite() t.framer.flushWrite()

View File

@ -53,7 +53,7 @@ import (
const ( const (
// The primary user agent // The primary user agent
primaryUA = "grpc-go/0.11" primaryUA = "grpc-go/1.0"
// http2MaxFrameLen specifies the max length of a HTTP2 frame. // http2MaxFrameLen specifies the max length of a HTTP2 frame.
http2MaxFrameLen = 16384 // 16KB frame http2MaxFrameLen = 16384 // 16KB frame
// http://http2.github.io/http2-spec/#SettingValues // http://http2.github.io/http2-spec/#SettingValues
@ -162,7 +162,7 @@ func (d *decodeState) processHeaderField(f hpack.HeaderField) {
switch f.Name { switch f.Name {
case "content-type": case "content-type":
if !validContentType(f.Value) { if !validContentType(f.Value) {
d.setErr(StreamErrorf(codes.FailedPrecondition, "transport: received the unexpected content-type %q", f.Value)) d.setErr(streamErrorf(codes.FailedPrecondition, "transport: received the unexpected content-type %q", f.Value))
return return
} }
case "grpc-encoding": case "grpc-encoding":
@ -170,7 +170,7 @@ func (d *decodeState) processHeaderField(f hpack.HeaderField) {
case "grpc-status": case "grpc-status":
code, err := strconv.Atoi(f.Value) code, err := strconv.Atoi(f.Value)
if err != nil { if err != nil {
d.setErr(StreamErrorf(codes.Internal, "transport: malformed grpc-status: %v", err)) d.setErr(streamErrorf(codes.Internal, "transport: malformed grpc-status: %v", err))
return return
} }
d.statusCode = codes.Code(code) d.statusCode = codes.Code(code)
@ -181,7 +181,7 @@ func (d *decodeState) processHeaderField(f hpack.HeaderField) {
var err error var err error
d.timeout, err = decodeTimeout(f.Value) d.timeout, err = decodeTimeout(f.Value)
if err != nil { if err != nil {
d.setErr(StreamErrorf(codes.Internal, "transport: malformed time-out: %v", err)) d.setErr(streamErrorf(codes.Internal, "transport: malformed time-out: %v", err))
return return
} }
case ":path": case ":path":
@ -253,6 +253,9 @@ func div(d, r time.Duration) int64 {
// TODO(zhaoq): It is the simplistic and not bandwidth efficient. Improve it. // TODO(zhaoq): It is the simplistic and not bandwidth efficient. Improve it.
func encodeTimeout(t time.Duration) string { func encodeTimeout(t time.Duration) string {
if t <= 0 {
return "0n"
}
if d := div(t, time.Nanosecond); d <= maxTimeoutValue { if d := div(t, time.Nanosecond); d <= maxTimeoutValue {
return strconv.FormatInt(d, 10) + "n" return strconv.FormatInt(d, 10) + "n"
} }
@ -349,7 +352,7 @@ func decodeGrpcMessageUnchecked(msg string) string {
for i := 0; i < lenMsg; i++ { for i := 0; i < lenMsg; i++ {
c := msg[i] c := msg[i]
if c == percentByte && i+2 < lenMsg { if c == percentByte && i+2 < lenMsg {
parsed, err := strconv.ParseInt(msg[i+1:i+3], 16, 8) parsed, err := strconv.ParseUint(msg[i+1:i+3], 16, 8)
if err != nil { if err != nil {
buf.WriteByte(c) buf.WriteByte(c)
} else { } else {

View File

@ -39,7 +39,6 @@ package transport // import "google.golang.org/grpc/transport"
import ( import (
"bytes" "bytes"
"errors"
"fmt" "fmt"
"io" "io"
"net" "net"
@ -170,6 +169,7 @@ type Stream struct {
st ServerTransport st ServerTransport
// ctx is the associated context of the stream. // ctx is the associated context of the stream.
ctx context.Context ctx context.Context
// cancel is always nil for client side Stream.
cancel context.CancelFunc cancel context.CancelFunc
// done is closed when the final status arrives. // done is closed when the final status arrives.
done chan struct{} done chan struct{}
@ -286,19 +286,12 @@ func (s *Stream) StatusDesc() string {
return s.statusDesc return s.statusDesc
} }
// ErrIllegalTrailerSet indicates that the trailer has already been set or it
// is too late to do so.
var ErrIllegalTrailerSet = errors.New("transport: trailer has been set")
// SetTrailer sets the trailer metadata which will be sent with the RPC status // SetTrailer sets the trailer metadata which will be sent with the RPC status
// by the server. This can only be called at most once. Server side only. // by the server. This can be called multiple times. Server side only.
func (s *Stream) SetTrailer(md metadata.MD) error { func (s *Stream) SetTrailer(md metadata.MD) error {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
if s.trailer != nil { s.trailer = metadata.Join(s.trailer, md)
return ErrIllegalTrailerSet
}
s.trailer = md.Copy()
return nil return nil
} }
@ -476,16 +469,16 @@ type ServerTransport interface {
Drain() Drain()
} }
// StreamErrorf creates an StreamError with the specified error code and description. // streamErrorf creates an StreamError with the specified error code and description.
func StreamErrorf(c codes.Code, format string, a ...interface{}) StreamError { func streamErrorf(c codes.Code, format string, a ...interface{}) StreamError {
return StreamError{ return StreamError{
Code: c, Code: c,
Desc: fmt.Sprintf(format, a...), Desc: fmt.Sprintf(format, a...),
} }
} }
// ConnectionErrorf creates an ConnectionError with the specified error description. // connectionErrorf creates an ConnectionError with the specified error description.
func ConnectionErrorf(temp bool, e error, format string, a ...interface{}) ConnectionError { func connectionErrorf(temp bool, e error, format string, a ...interface{}) ConnectionError {
return ConnectionError{ return ConnectionError{
Desc: fmt.Sprintf(format, a...), Desc: fmt.Sprintf(format, a...),
temp: temp, temp: temp,
@ -522,10 +515,10 @@ func (e ConnectionError) Origin() error {
var ( var (
// ErrConnClosing indicates that the transport is closing. // ErrConnClosing indicates that the transport is closing.
ErrConnClosing = ConnectionError{Desc: "transport is closing", temp: true} ErrConnClosing = connectionErrorf(true, nil, "transport is closing")
// ErrStreamDrain indicates that the stream is rejected by the server because // ErrStreamDrain indicates that the stream is rejected by the server because
// the server stops accepting new RPCs. // the server stops accepting new RPCs.
ErrStreamDrain = StreamErrorf(codes.Unavailable, "the server stops accepting new RPCs") ErrStreamDrain = streamErrorf(codes.Unavailable, "the server stops accepting new RPCs")
) )
// StreamError is an error that only affects one stream within a connection. // StreamError is an error that only affects one stream within a connection.
@ -542,9 +535,9 @@ func (e StreamError) Error() string {
func ContextErr(err error) StreamError { func ContextErr(err error) StreamError {
switch err { switch err {
case context.DeadlineExceeded: case context.DeadlineExceeded:
return StreamErrorf(codes.DeadlineExceeded, "%v", err) return streamErrorf(codes.DeadlineExceeded, "%v", err)
case context.Canceled: case context.Canceled:
return StreamErrorf(codes.Canceled, "%v", err) return streamErrorf(codes.Canceled, "%v", err)
} }
panic(fmt.Sprintf("Unexpected error from context packet: %v", err)) panic(fmt.Sprintf("Unexpected error from context packet: %v", err))
} }