Merge pull request #8840 from gyuho/health-balancer

*: refactor clientv3 balancer, upgrade gRPC to v1.7.2
Gyu-Ho Lee 2017-11-10 15:41:00 -08:00 committed by GitHub
commit b64c1bfce6
37 changed files with 3101 additions and 2261 deletions


@ -1,439 +0,0 @@
// Copyright 2016 The etcd Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package clientv3
import (
"context"
"net/url"
"strings"
"sync"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
// ErrNoAddrAvilable is returned by Get() when the balancer does not have
// any active connection to endpoints at the time.
// This error is returned only when opts.BlockingWait is true.
var ErrNoAddrAvilable = status.Error(codes.Unavailable, "there is no address available")
type notifyMsg int
const (
notifyReset notifyMsg = iota
notifyNext
)
// simpleBalancer does the bare minimum to expose multiple eps
// to the grpc reconnection code path
type simpleBalancer struct {
// addrs are the client's endpoint addresses for grpc
addrs []grpc.Address
// eps holds the raw endpoints from the client
eps []string
// notifyCh notifies grpc of the set of addresses for connecting
notifyCh chan []grpc.Address
// readyc closes once the first connection is up
readyc chan struct{}
readyOnce sync.Once
// mu protects all fields below.
mu sync.RWMutex
// upc closes when pinAddr transitions from empty to non-empty or the balancer closes.
upc chan struct{}
// downc closes when grpc calls down() on pinAddr
downc chan struct{}
// stopc is closed to signal updateNotifyLoop should stop.
stopc chan struct{}
// donec closes when all goroutines are exited
donec chan struct{}
// updateAddrsC notifies updateNotifyLoop to update addrs.
updateAddrsC chan notifyMsg
// grpc issues TLS cert checks using the string passed into dial so
// that string must be the host. To recover the full scheme://host URL,
// have a map from hosts to the original endpoint.
hostPort2ep map[string]string
// pinAddr is the currently pinned address; set to the empty string on
// initialization and shutdown.
pinAddr string
closed bool
}
func newSimpleBalancer(eps []string) *simpleBalancer {
notifyCh := make(chan []grpc.Address)
addrs := eps2addrs(eps)
sb := &simpleBalancer{
addrs: addrs,
eps: eps,
notifyCh: notifyCh,
readyc: make(chan struct{}),
upc: make(chan struct{}),
stopc: make(chan struct{}),
downc: make(chan struct{}),
donec: make(chan struct{}),
updateAddrsC: make(chan notifyMsg),
hostPort2ep: getHostPort2ep(eps),
}
close(sb.downc)
go sb.updateNotifyLoop()
return sb
}
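For orientation, a balancer like this is handed to gRPC at dial time; the integration tests further down in this diff dial with an empty target and grpc.WithBalancer so that the balancer, not the target string, chooses the address. A minimal sketch of that wiring inside package clientv3 (the helper name is made up, and grpc.WithInsecure stands in for whatever credentials the caller actually uses):

package clientv3

import "google.golang.org/grpc"

// dialWithSimpleBalancer is a hypothetical helper: newSimpleBalancer starts
// updateNotifyLoop in the background, and grpc.WithBalancer lets it feed
// endpoint addresses to gRPC through Notify().
func dialWithSimpleBalancer(eps []string) (*grpc.ClientConn, error) {
	sb := newSimpleBalancer(eps)
	return grpc.Dial("", grpc.WithInsecure(), grpc.WithBalancer(sb))
}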
func (b *simpleBalancer) Start(target string, config grpc.BalancerConfig) error { return nil }
func (b *simpleBalancer) ConnectNotify() <-chan struct{} {
b.mu.Lock()
defer b.mu.Unlock()
return b.upc
}
func (b *simpleBalancer) ready() <-chan struct{} { return b.readyc }
func (b *simpleBalancer) endpoint(hostPort string) string {
b.mu.Lock()
defer b.mu.Unlock()
return b.hostPort2ep[hostPort]
}
func (b *simpleBalancer) endpoints() []string {
b.mu.RLock()
defer b.mu.RUnlock()
return b.eps
}
func (b *simpleBalancer) pinned() string {
b.mu.RLock()
defer b.mu.RUnlock()
return b.pinAddr
}
func getHostPort2ep(eps []string) map[string]string {
hm := make(map[string]string, len(eps))
for i := range eps {
_, host, _ := parseEndpoint(eps[i])
hm[host] = eps[i]
}
return hm
}
func (b *simpleBalancer) updateAddrs(eps ...string) {
np := getHostPort2ep(eps)
b.mu.Lock()
match := len(np) == len(b.hostPort2ep)
for k, v := range np {
if b.hostPort2ep[k] != v {
match = false
break
}
}
if match {
// same endpoints, so no need to update address
b.mu.Unlock()
return
}
b.hostPort2ep = np
b.addrs, b.eps = eps2addrs(eps), eps
// updating notifyCh can trigger new connections,
// only update addrs if all connections are down
// or addrs does not include pinAddr.
update := !hasAddr(b.addrs, b.pinAddr)
b.mu.Unlock()
if update {
select {
case b.updateAddrsC <- notifyNext:
case <-b.stopc:
}
}
}
func (b *simpleBalancer) next() {
b.mu.RLock()
downc := b.downc
b.mu.RUnlock()
select {
case b.updateAddrsC <- notifyNext:
case <-b.stopc:
}
// wait until disconnect so new RPCs are not issued on old connection
select {
case <-downc:
case <-b.stopc:
}
}
func hasAddr(addrs []grpc.Address, targetAddr string) bool {
for _, addr := range addrs {
if targetAddr == addr.Addr {
return true
}
}
return false
}
func (b *simpleBalancer) updateNotifyLoop() {
defer close(b.donec)
for {
b.mu.RLock()
upc, downc, addr := b.upc, b.downc, b.pinAddr
b.mu.RUnlock()
// downc or upc should be closed
select {
case <-downc:
downc = nil
default:
}
select {
case <-upc:
upc = nil
default:
}
switch {
case downc == nil && upc == nil:
// stale
select {
case <-b.stopc:
return
default:
}
case downc == nil:
b.notifyAddrs(notifyReset)
select {
case <-upc:
case msg := <-b.updateAddrsC:
b.notifyAddrs(msg)
case <-b.stopc:
return
}
case upc == nil:
select {
// close connections that are not the pinned address
case b.notifyCh <- []grpc.Address{{Addr: addr}}:
case <-downc:
case <-b.stopc:
return
}
select {
case <-downc:
b.notifyAddrs(notifyReset)
case msg := <-b.updateAddrsC:
b.notifyAddrs(msg)
case <-b.stopc:
return
}
}
}
}
func (b *simpleBalancer) notifyAddrs(msg notifyMsg) {
if msg == notifyNext {
select {
case b.notifyCh <- []grpc.Address{}:
case <-b.stopc:
return
}
}
b.mu.RLock()
addrs := b.addrs
pinAddr := b.pinAddr
downc := b.downc
b.mu.RUnlock()
var waitDown bool
if pinAddr != "" {
waitDown = true
for _, a := range addrs {
if a.Addr == pinAddr {
waitDown = false
}
}
}
select {
case b.notifyCh <- addrs:
if waitDown {
select {
case <-downc:
case <-b.stopc:
}
}
case <-b.stopc:
}
}
func (b *simpleBalancer) Up(addr grpc.Address) func(error) {
f, _ := b.up(addr)
return f
}
func (b *simpleBalancer) up(addr grpc.Address) (func(error), bool) {
b.mu.Lock()
defer b.mu.Unlock()
// gRPC might call Up after it called Close. We add this check
// to "fix" it up at application layer. Otherwise, will panic
// if b.upc is already closed.
if b.closed {
return func(err error) {}, false
}
// gRPC might call Up on a stale address.
// Prevent updating pinAddr with a stale address.
if !hasAddr(b.addrs, addr.Addr) {
return func(err error) {}, false
}
if b.pinAddr != "" {
if logger.V(4) {
logger.Infof("clientv3/balancer: %q is up but not pinned (already pinned %q)", addr.Addr, b.pinAddr)
}
return func(err error) {}, false
}
// notify waiting Get()s and pin first connected address
close(b.upc)
b.downc = make(chan struct{})
b.pinAddr = addr.Addr
if logger.V(4) {
logger.Infof("clientv3/balancer: pin %q", addr.Addr)
}
// notify client that a connection is up
b.readyOnce.Do(func() { close(b.readyc) })
return func(err error) {
b.mu.Lock()
b.upc = make(chan struct{})
close(b.downc)
b.pinAddr = ""
b.mu.Unlock()
if logger.V(4) {
logger.Infof("clientv3/balancer: unpin %q (%q)", addr.Addr, err.Error())
}
}, true
}
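Up returns the down closure that gRPC invokes when the pinned connection fails; only the first Up pins, and the closure resets upc/downc so a later Up can pin a different endpoint. A hypothetical walkthrough of that cycle (addresses are placeholders):

package clientv3

import (
	"errors"

	"google.golang.org/grpc"
)

// pinAndUnpin is an illustrative sketch of the Up/down cycle, not part of the package.
func pinAndUnpin(sb *simpleBalancer) {
	// The first Up pins the address, closes upc, and unblocks ready() and blocking Get()s.
	down := sb.Up(grpc.Address{Addr: "10.0.1.10:2379"})

	// While something is pinned, Up on another address is a no-op.
	_ = sb.Up(grpc.Address{Addr: "10.0.1.11:2379"})

	// When the pinned connection breaks, gRPC calls the down closure:
	// pinAddr is cleared, downc is closed, and upc is recreated so the
	// next Up can pin a different endpoint.
	down(errors.New("connection lost"))
}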
func (b *simpleBalancer) Get(ctx context.Context, opts grpc.BalancerGetOptions) (grpc.Address, func(), error) {
var (
addr string
closed bool
)
// If opts.BlockingWait is false (for fail-fast RPCs), it should return
// an address it has notified via Notify immediately instead of blocking.
if !opts.BlockingWait {
b.mu.RLock()
closed = b.closed
addr = b.pinAddr
b.mu.RUnlock()
if closed {
return grpc.Address{Addr: ""}, nil, grpc.ErrClientConnClosing
}
if addr == "" {
return grpc.Address{Addr: ""}, nil, ErrNoAddrAvilable
}
return grpc.Address{Addr: addr}, func() {}, nil
}
for {
b.mu.RLock()
ch := b.upc
b.mu.RUnlock()
select {
case <-ch:
case <-b.donec:
return grpc.Address{Addr: ""}, nil, grpc.ErrClientConnClosing
case <-ctx.Done():
return grpc.Address{Addr: ""}, nil, ctx.Err()
}
b.mu.RLock()
closed = b.closed
addr = b.pinAddr
b.mu.RUnlock()
// Close() which sets b.closed = true can be called before Get(), Get() must exit if balancer is closed.
if closed {
return grpc.Address{Addr: ""}, nil, grpc.ErrClientConnClosing
}
if addr != "" {
break
}
}
return grpc.Address{Addr: addr}, func() {}, nil
}
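Get therefore has two modes keyed off opts.BlockingWait: fail-fast calls return ErrNoAddrAvilable immediately when nothing is pinned, while blocking calls wait on upc until an address is pinned, the balancer closes, or the caller's context expires. A short illustrative sketch (the timeout value is arbitrary):

package clientv3

import (
	"context"
	"time"

	"google.golang.org/grpc"
)

// getModes is a hypothetical sketch contrasting fail-fast and blocking Get.
func getModes(sb *simpleBalancer) {
	// Fail-fast: no waiting; errors out if no address is pinned yet.
	if _, _, err := sb.Get(context.Background(), grpc.BalancerGetOptions{BlockingWait: false}); err == ErrNoAddrAvilable {
		// nothing is connected yet
	}

	// Blocking: waits for a pin, a Close(), or context expiration.
	ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
	defer cancel()
	addr, put, err := sb.Get(ctx, grpc.BalancerGetOptions{BlockingWait: true})
	if err == nil {
		defer put()
		_ = addr.Addr // the pinned host:port
	}
}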
func (b *simpleBalancer) Notify() <-chan []grpc.Address { return b.notifyCh }
func (b *simpleBalancer) Close() error {
b.mu.Lock()
// In case gRPC calls close twice. TODO: remove the check
// when we are sure that gRPC won't call close twice.
if b.closed {
b.mu.Unlock()
<-b.donec
return nil
}
b.closed = true
close(b.stopc)
b.pinAddr = ""
// In the case of following scenario:
// 1. upc is not closed; no pinned address
// 2. client issues an RPC, calling invoke(), which calls Get(), enters for loop, blocks
// 3. client.conn.Close() calls balancer.Close(); closed = true
// 4. for loop in Get() never exits since ctx is the context passed in by the client and may not be canceled
// we must close upc so Get() exits from blocking on upc
select {
case <-b.upc:
default:
// terminate all waiting Get()s
close(b.upc)
}
b.mu.Unlock()
// wait for updateNotifyLoop to finish
<-b.donec
close(b.notifyCh)
return nil
}
func getHost(ep string) string {
url, uerr := url.Parse(ep)
if uerr != nil || !strings.Contains(ep, "://") {
return ep
}
return url.Host
}
func eps2addrs(eps []string) []grpc.Address {
addrs := make([]grpc.Address, len(eps))
for i := range eps {
addrs[i].Addr = getHost(eps[i])
}
return addrs
}
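getHost, eps2addrs, and getHostPort2ep translate configured endpoints into the bare host:port strings gRPC dials, while hostPort2ep keeps the reverse mapping so the original scheme://host form (needed for TLS server-name checks) can be recovered. A worked example with made-up endpoints:

package clientv3

import "fmt"

// endpointHelpersExample is a hypothetical illustration of the helpers above.
func endpointHelpersExample() {
	eps := []string{"https://10.0.1.10:2379", "10.0.1.11:2379"}

	fmt.Println(getHost(eps[0])) // "10.0.1.10:2379" (scheme stripped)
	fmt.Println(getHost(eps[1])) // "10.0.1.11:2379" (no scheme, returned as-is)

	// eps2addrs builds the grpc.Address slice sent over notifyCh.
	addrs := eps2addrs(eps)

	// getHostPort2ep remembers the original endpoint for each host:port,
	// e.g. "10.0.1.10:2379" -> "https://10.0.1.10:2379".
	hp2ep := getHostPort2ep(eps)

	_, _ = addrs, hp2ep
}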


@ -121,6 +121,19 @@ func (c *Client) SetEndpoints(eps ...string) {
	c.cfg.Endpoints = eps
	c.mu.Unlock()
	c.balancer.updateAddrs(eps...)
+	// updating notifyCh can trigger new connections,
+	// need update addrs if all connections are down
+	// or addrs does not include pinAddr.
+	c.balancer.mu.RLock()
+	update := !hasAddr(c.balancer.addrs, c.balancer.pinAddr)
+	c.balancer.mu.RUnlock()
+	if update {
+		select {
+		case c.balancer.updateAddrsC <- notifyNext:
+		case <-c.balancer.stopc:
+		}
+	}
}

// Sync synchronizes client's endpoints with the known endpoints from the etcd membership.
@ -378,9 +391,9 @@ func newClient(cfg *Config) (*Client, error) {
		client.Password = cfg.Password
	}
-	sb := newSimpleBalancer(cfg.Endpoints)
-	hc := func(ep string) (bool, error) { return grpcHealthCheck(client, ep) }
-	client.balancer = newHealthBalancer(sb, cfg.DialTimeout, hc)
+	client.balancer = newHealthBalancer(cfg.Endpoints, cfg.DialTimeout, func(ep string) (bool, error) {
+		return grpcHealthCheck(client, ep)
+	})
	// use Endpoints[0] so that for https:// without any tls config given, then
	// grpc will assume the certificate server name is the endpoint host.
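From the caller's side, SetEndpoints is the public path into updateAddrs; after this change it also nudges the balancer onto another endpoint when the pinned address drops out of the new set. A hedged usage sketch (endpoint addresses and timeout are placeholders):

package main

import (
	"log"
	"time"

	"github.com/coreos/etcd/clientv3"
)

func main() {
	cli, err := clientv3.New(clientv3.Config{
		Endpoints:   []string{"10.0.1.10:2379"},
		DialTimeout: 5 * time.Second,
	})
	if err != nil {
		log.Fatal(err)
	}
	defer cli.Close()

	// Swapping the endpoint set drives balancer.updateAddrs; if the currently
	// pinned address is not in the new set, the balancer is told to move on.
	cli.SetEndpoints("10.0.1.11:2379", "10.0.1.12:2379")
}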


@ -16,6 +16,9 @@ package clientv3
import ( import (
"context" "context"
"errors"
"net/url"
"strings"
"sync" "sync"
"time" "time"
@ -25,207 +28,549 @@ import (
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
) )
const minHealthRetryDuration = 3 * time.Second const (
const unknownService = "unknown service grpc.health.v1.Health" minHealthRetryDuration = 3 * time.Second
unknownService = "unknown service grpc.health.v1.Health"
)
// ErrNoAddrAvilable is returned by Get() when the balancer does not have
// any active connection to endpoints at the time.
// This error is returned only when opts.BlockingWait is true.
var ErrNoAddrAvilable = status.Error(codes.Unavailable, "there is no address available")
type healthCheckFunc func(ep string) (bool, error) type healthCheckFunc func(ep string) (bool, error)
// healthBalancer wraps a balancer so that it uses health checking type notifyMsg int
// to choose its endpoints.
const (
notifyReset notifyMsg = iota
notifyNext
)
// healthBalancer does the bare minimum to expose multiple eps
// to the grpc reconnection code path
type healthBalancer struct { type healthBalancer struct {
*simpleBalancer // addrs are the client's endpoint addresses for grpc
addrs []grpc.Address
// eps holds the raw endpoints from the client
eps []string
// notifyCh notifies grpc of the set of addresses for connecting
notifyCh chan []grpc.Address
// readyc closes once the first connection is up
readyc chan struct{}
readyOnce sync.Once
// healthCheck checks an endpoint's health. // healthCheck checks an endpoint's health.
healthCheck healthCheckFunc healthCheck healthCheckFunc
healthCheckTimeout time.Duration healthCheckTimeout time.Duration
// mu protects addrs, eps, unhealthy map, and stopc. unhealthyMu sync.RWMutex
unhealthyHostPorts map[string]time.Time
// mu protects all fields below.
mu sync.RWMutex mu sync.RWMutex
// addrs stores all grpc addresses associated with the balancer. // upc closes when pinAddr transitions from empty to non-empty or the balancer closes.
addrs []grpc.Address upc chan struct{}
// eps stores all client endpoints // downc closes when grpc calls down() on pinAddr
eps []string downc chan struct{}
// unhealthy tracks the last unhealthy time of endpoints.
unhealthy map[string]time.Time
// stopc is closed to signal updateNotifyLoop should stop.
stopc chan struct{} stopc chan struct{}
stopOnce sync.Once stopOnce sync.Once
wg sync.WaitGroup
// donec closes when all goroutines are exited
donec chan struct{}
// updateAddrsC notifies updateNotifyLoop to update addrs.
updateAddrsC chan notifyMsg
// grpc issues TLS cert checks using the string passed into dial so
// that string must be the host. To recover the full scheme://host URL,
// have a map from hosts to the original endpoint.
hostPort2ep map[string]string hostPort2ep map[string]string
wg sync.WaitGroup // pinAddr is the currently pinned address; set to the empty string on
// initialization and shutdown.
pinAddr string
closed bool
} }
func newHealthBalancer(b *simpleBalancer, timeout time.Duration, hc healthCheckFunc) *healthBalancer { func newHealthBalancer(eps []string, timeout time.Duration, hc healthCheckFunc) *healthBalancer {
notifyCh := make(chan []grpc.Address)
addrs := eps2addrs(eps)
hb := &healthBalancer{ hb := &healthBalancer{
simpleBalancer: b, addrs: addrs,
eps: eps,
notifyCh: notifyCh,
readyc: make(chan struct{}),
healthCheck: hc, healthCheck: hc,
eps: b.endpoints(), unhealthyHostPorts: make(map[string]time.Time),
addrs: eps2addrs(b.endpoints()), upc: make(chan struct{}),
hostPort2ep: getHostPort2ep(b.endpoints()),
unhealthy: make(map[string]time.Time),
stopc: make(chan struct{}), stopc: make(chan struct{}),
downc: make(chan struct{}),
donec: make(chan struct{}),
updateAddrsC: make(chan notifyMsg),
hostPort2ep: getHostPort2ep(eps),
} }
if timeout < minHealthRetryDuration { if timeout < minHealthRetryDuration {
timeout = minHealthRetryDuration timeout = minHealthRetryDuration
} }
hb.healthCheckTimeout = timeout hb.healthCheckTimeout = timeout
close(hb.downc)
go hb.updateNotifyLoop()
hb.wg.Add(1) hb.wg.Add(1)
go func() { go func() {
defer hb.wg.Done() defer hb.wg.Done()
hb.updateUnhealthy(timeout) hb.updateUnhealthy()
}() }()
return hb return hb
} }
func (hb *healthBalancer) Up(addr grpc.Address) func(error) { func (b *healthBalancer) Start(target string, config grpc.BalancerConfig) error { return nil }
f, used := hb.up(addr)
if !used { func (b *healthBalancer) ConnectNotify() <-chan struct{} {
return f b.mu.Lock()
defer b.mu.Unlock()
return b.upc
}
func (b *healthBalancer) ready() <-chan struct{} { return b.readyc }
func (b *healthBalancer) endpoint(hostPort string) string {
b.mu.RLock()
defer b.mu.RUnlock()
return b.hostPort2ep[hostPort]
}
func (b *healthBalancer) pinned() string {
b.mu.RLock()
defer b.mu.RUnlock()
return b.pinAddr
}
func (b *healthBalancer) hostPortError(hostPort string, err error) {
if b.endpoint(hostPort) == "" {
if logger.V(4) {
logger.Infof("clientv3/balancer: %q is stale (skip marking as unhealthy on %q)", hostPort, err.Error())
} }
return func(err error) { return
// If connected to a black hole endpoint or a killed server, the gRPC ping }
// timeout will induce a network I/O error, and retrying until success;
// finding healthy endpoint on retry could take several timeouts and redials. b.unhealthyMu.Lock()
// To avoid wasting retries, gray-list unhealthy endpoints. b.unhealthyHostPorts[hostPort] = time.Now()
hb.hostPortError(addr.Addr, err) b.unhealthyMu.Unlock()
f(err) if logger.V(4) {
logger.Infof("clientv3/balancer: %q is marked unhealthy (%q)", hostPort, err.Error())
} }
} }
func (hb *healthBalancer) up(addr grpc.Address) (func(error), bool) { func (b *healthBalancer) removeUnhealthy(hostPort, msg string) {
if !hb.mayPin(addr) { if b.endpoint(hostPort) == "" {
return func(err error) {}, false if logger.V(4) {
logger.Infof("clientv3/balancer: %q was not in unhealthy (%q)", hostPort, msg)
}
return
}
b.unhealthyMu.Lock()
delete(b.unhealthyHostPorts, hostPort)
b.unhealthyMu.Unlock()
if logger.V(4) {
logger.Infof("clientv3/balancer: %q is removed from unhealthy (%q)", hostPort, msg)
} }
return hb.simpleBalancer.up(addr)
} }
func (hb *healthBalancer) Close() error { func (b *healthBalancer) countUnhealthy() (count int) {
hb.stopOnce.Do(func() { close(hb.stopc) }) b.unhealthyMu.RLock()
hb.wg.Wait() count = len(b.unhealthyHostPorts)
return hb.simpleBalancer.Close() b.unhealthyMu.RUnlock()
return count
} }
func (hb *healthBalancer) updateAddrs(eps ...string) { func (b *healthBalancer) isUnhealthy(hostPort string) (unhealthy bool) {
addrs, hostPort2ep := eps2addrs(eps), getHostPort2ep(eps) b.unhealthyMu.RLock()
hb.mu.Lock() _, unhealthy = b.unhealthyHostPorts[hostPort]
hb.addrs, hb.eps, hb.hostPort2ep = addrs, eps, hostPort2ep b.unhealthyMu.RUnlock()
hb.unhealthy = make(map[string]time.Time) return unhealthy
hb.mu.Unlock()
hb.simpleBalancer.updateAddrs(eps...)
} }
func (hb *healthBalancer) endpoint(host string) string { func (b *healthBalancer) cleanupUnhealthy() {
hb.mu.RLock() b.unhealthyMu.Lock()
defer hb.mu.RUnlock() for k, v := range b.unhealthyHostPorts {
return hb.hostPort2ep[host] if time.Since(v) > b.healthCheckTimeout {
delete(b.unhealthyHostPorts, k)
if logger.V(4) {
logger.Infof("clientv3/balancer: removed %q from unhealthy after %v", k, b.healthCheckTimeout)
}
}
}
b.unhealthyMu.Unlock()
} }
func (hb *healthBalancer) endpoints() []string { func (b *healthBalancer) liveAddrs() ([]grpc.Address, map[string]struct{}) {
hb.mu.RLock() unhealthyCnt := b.countUnhealthy()
defer hb.mu.RUnlock()
return hb.eps b.mu.RLock()
defer b.mu.RUnlock()
hbAddrs := b.addrs
if len(b.addrs) == 1 || unhealthyCnt == 0 || unhealthyCnt == len(b.addrs) {
liveHostPorts := make(map[string]struct{}, len(b.hostPort2ep))
for k := range b.hostPort2ep {
liveHostPorts[k] = struct{}{}
}
return hbAddrs, liveHostPorts
}
addrs := make([]grpc.Address, 0, len(b.addrs)-unhealthyCnt)
liveHostPorts := make(map[string]struct{}, len(addrs))
for _, addr := range b.addrs {
if !b.isUnhealthy(addr.Addr) {
addrs = append(addrs, addr)
liveHostPorts[addr.Addr] = struct{}{}
}
}
return addrs, liveHostPorts
} }
func (hb *healthBalancer) updateUnhealthy(timeout time.Duration) { func (b *healthBalancer) updateUnhealthy() {
for { for {
select { select {
case <-time.After(timeout): case <-time.After(b.healthCheckTimeout):
hb.mu.Lock() b.cleanupUnhealthy()
for k, v := range hb.unhealthy { pinned := b.pinned()
if time.Since(v) > timeout { if pinned == "" || b.isUnhealthy(pinned) {
delete(hb.unhealthy, k) select {
if logger.V(4) { case b.updateAddrsC <- notifyNext:
logger.Infof("clientv3/health-balancer: removes %q from unhealthy after %v", k, timeout) case <-b.stopc:
return
} }
} }
} case <-b.stopc:
hb.mu.Unlock()
eps := []string{}
for _, addr := range hb.liveAddrs() {
eps = append(eps, hb.endpoint(addr.Addr))
}
hb.simpleBalancer.updateAddrs(eps...)
case <-hb.stopc:
return return
} }
} }
} }
func (hb *healthBalancer) liveAddrs() []grpc.Address { func (b *healthBalancer) updateAddrs(eps ...string) {
hb.mu.RLock() np := getHostPort2ep(eps)
defer hb.mu.RUnlock()
hbAddrs := hb.addrs b.mu.Lock()
if len(hb.addrs) == 1 || len(hb.unhealthy) == 0 || len(hb.unhealthy) == len(hb.addrs) { defer b.mu.Unlock()
return hbAddrs
} match := len(np) == len(b.hostPort2ep)
addrs := make([]grpc.Address, 0, len(hb.addrs)-len(hb.unhealthy)) if match {
for _, addr := range hb.addrs { for k, v := range np {
if _, unhealthy := hb.unhealthy[addr.Addr]; !unhealthy { if b.hostPort2ep[k] != v {
addrs = append(addrs, addr) match = false
break
} }
} }
return addrs }
if match {
// same endpoints, so no need to update address
return
}
b.hostPort2ep = np
b.addrs, b.eps = eps2addrs(eps), eps
b.unhealthyMu.Lock()
b.unhealthyHostPorts = make(map[string]time.Time)
b.unhealthyMu.Unlock()
} }
func (hb *healthBalancer) hostPortError(hostPort string, err error) { func (b *healthBalancer) next() {
hb.mu.Lock() b.mu.RLock()
if _, ok := hb.hostPort2ep[hostPort]; ok { downc := b.downc
hb.unhealthy[hostPort] = time.Now() b.mu.RUnlock()
select {
case b.updateAddrsC <- notifyNext:
case <-b.stopc:
}
// wait until disconnect so new RPCs are not issued on old connection
select {
case <-downc:
case <-b.stopc:
}
}
func (b *healthBalancer) updateNotifyLoop() {
defer close(b.donec)
for {
b.mu.RLock()
upc, downc, addr := b.upc, b.downc, b.pinAddr
b.mu.RUnlock()
// downc or upc should be closed
select {
case <-downc:
downc = nil
default:
}
select {
case <-upc:
upc = nil
default:
}
switch {
case downc == nil && upc == nil:
// stale
select {
case <-b.stopc:
return
default:
}
case downc == nil:
b.notifyAddrs(notifyReset)
select {
case <-upc:
case msg := <-b.updateAddrsC:
b.notifyAddrs(msg)
case <-b.stopc:
return
}
case upc == nil:
select {
// close connections that are not the pinned address
case b.notifyCh <- []grpc.Address{{Addr: addr}}:
case <-downc:
case <-b.stopc:
return
}
select {
case <-downc:
b.notifyAddrs(notifyReset)
case msg := <-b.updateAddrsC:
b.notifyAddrs(msg)
case <-b.stopc:
return
}
}
}
}
func (b *healthBalancer) notifyAddrs(msg notifyMsg) {
if msg == notifyNext {
select {
case b.notifyCh <- []grpc.Address{}:
case <-b.stopc:
return
}
}
b.mu.RLock()
pinAddr := b.pinAddr
downc := b.downc
b.mu.RUnlock()
addrs, hostPorts := b.liveAddrs()
var waitDown bool
if pinAddr != "" {
_, ok := hostPorts[pinAddr]
waitDown = !ok
}
select {
case b.notifyCh <- addrs:
if waitDown {
select {
case <-downc:
case <-b.stopc:
}
}
case <-b.stopc:
}
}
func (b *healthBalancer) Up(addr grpc.Address) func(error) {
if !b.mayPin(addr) {
return func(err error) {}
}
b.mu.Lock()
defer b.mu.Unlock()
// gRPC might call Up after it called Close. We add this check
// to "fix" it up at application layer. Otherwise, will panic
// if b.upc is already closed.
if b.closed {
return func(err error) {}
}
// gRPC might call Up on a stale address.
// Prevent updating pinAddr with a stale address.
if !hasAddr(b.addrs, addr.Addr) {
return func(err error) {}
}
if b.pinAddr != "" {
if logger.V(4) { if logger.V(4) {
logger.Infof("clientv3/health-balancer: marking %q as unhealthy (%q)", hostPort, err.Error()) logger.Infof("clientv3/balancer: %q is up but not pinned (already pinned %q)", addr.Addr, b.pinAddr)
}
return func(err error) {}
}
// notify waiting Get()s and pin first connected address
close(b.upc)
b.downc = make(chan struct{})
b.pinAddr = addr.Addr
if logger.V(4) {
logger.Infof("clientv3/balancer: pin %q", addr.Addr)
}
// notify client that a connection is up
b.readyOnce.Do(func() { close(b.readyc) })
return func(err error) {
// If connected to a black hole endpoint or a killed server, the gRPC ping
// timeout will induce a network I/O error, and retrying until success;
// finding healthy endpoint on retry could take several timeouts and redials.
// To avoid wasting retries, gray-list unhealthy endpoints.
b.hostPortError(addr.Addr, err)
b.mu.Lock()
b.upc = make(chan struct{})
close(b.downc)
b.pinAddr = ""
b.mu.Unlock()
if logger.V(4) {
logger.Infof("clientv3/balancer: unpin %q (%q)", addr.Addr, err.Error())
} }
} }
hb.mu.Unlock()
} }
func (hb *healthBalancer) mayPin(addr grpc.Address) bool { func (b *healthBalancer) mayPin(addr grpc.Address) bool {
hb.mu.RLock() if b.endpoint(addr.Addr) == "" { // stale host:port
if _, ok := hb.hostPort2ep[addr.Addr]; !ok { // stale host:port
hb.mu.RUnlock()
return false return false
} }
skip := len(hb.addrs) == 1 || len(hb.unhealthy) == 0 || len(hb.addrs) == len(hb.unhealthy)
failedTime, bad := hb.unhealthy[addr.Addr] b.unhealthyMu.RLock()
dur := hb.healthCheckTimeout unhealthyCnt := len(b.unhealthyHostPorts)
hb.mu.RUnlock() failedTime, bad := b.unhealthyHostPorts[addr.Addr]
b.unhealthyMu.RUnlock()
b.mu.RLock()
skip := len(b.addrs) == 1 || unhealthyCnt == 0 || len(b.addrs) == unhealthyCnt
b.mu.RUnlock()
if skip || !bad { if skip || !bad {
return true return true
} }
// prevent isolated member's endpoint from being infinitely retried, as follows: // prevent isolated member's endpoint from being infinitely retried, as follows:
// 1. keepalive pings detects GoAway with http2.ErrCodeEnhanceYourCalm // 1. keepalive pings detects GoAway with http2.ErrCodeEnhanceYourCalm
// 2. balancer 'Up' unpins with grpc: failed with network I/O error // 2. balancer 'Up' unpins with grpc: failed with network I/O error
// 3. grpc-healthcheck still SERVING, thus retry to pin // 3. grpc-healthcheck still SERVING, thus retry to pin
// instead, return before grpc-healthcheck if failed within healthcheck timeout // instead, return before grpc-healthcheck if failed within healthcheck timeout
if elapsed := time.Since(failedTime); elapsed < dur { if elapsed := time.Since(failedTime); elapsed < b.healthCheckTimeout {
if logger.V(4) { if logger.V(4) {
logger.Infof("clientv3/health-balancer: %q is up but not pinned (failed %v ago, require minimum %v after failure)", addr.Addr, elapsed, dur) logger.Infof("clientv3/balancer: %q is up but not pinned (failed %v ago, require minimum %v after failure)", addr.Addr, elapsed, b.healthCheckTimeout)
} }
return false return false
} }
if ok, _ := hb.healthCheck(addr.Addr); ok {
hb.mu.Lock() if ok, _ := b.healthCheck(addr.Addr); ok {
delete(hb.unhealthy, addr.Addr) b.removeUnhealthy(addr.Addr, "health check success")
hb.mu.Unlock()
if logger.V(4) {
logger.Infof("clientv3/health-balancer: %q is healthy (health check success)", addr.Addr)
}
return true return true
} }
hb.mu.Lock()
hb.unhealthy[addr.Addr] = time.Now() b.hostPortError(addr.Addr, errors.New("health check failed"))
hb.mu.Unlock()
if logger.V(4) {
logger.Infof("clientv3/health-balancer: %q becomes unhealthy (health check failed)", addr.Addr)
}
return false return false
} }
func (b *healthBalancer) Get(ctx context.Context, opts grpc.BalancerGetOptions) (grpc.Address, func(), error) {
var (
addr string
closed bool
)
// If opts.BlockingWait is false (for fail-fast RPCs), it should return
// an address it has notified via Notify immediately instead of blocking.
if !opts.BlockingWait {
b.mu.RLock()
closed = b.closed
addr = b.pinAddr
b.mu.RUnlock()
if closed {
return grpc.Address{Addr: ""}, nil, grpc.ErrClientConnClosing
}
if addr == "" {
return grpc.Address{Addr: ""}, nil, ErrNoAddrAvilable
}
return grpc.Address{Addr: addr}, func() {}, nil
}
for {
b.mu.RLock()
ch := b.upc
b.mu.RUnlock()
select {
case <-ch:
case <-b.donec:
return grpc.Address{Addr: ""}, nil, grpc.ErrClientConnClosing
case <-ctx.Done():
return grpc.Address{Addr: ""}, nil, ctx.Err()
}
b.mu.RLock()
closed = b.closed
addr = b.pinAddr
b.mu.RUnlock()
// Close() which sets b.closed = true can be called before Get(), Get() must exit if balancer is closed.
if closed {
return grpc.Address{Addr: ""}, nil, grpc.ErrClientConnClosing
}
if addr != "" {
break
}
}
return grpc.Address{Addr: addr}, func() {}, nil
}
func (b *healthBalancer) Notify() <-chan []grpc.Address { return b.notifyCh }
func (b *healthBalancer) Close() error {
b.mu.Lock()
// In case gRPC calls close twice. TODO: remove the checking
// when we are sure that gRPC wont call close twice.
if b.closed {
b.mu.Unlock()
<-b.donec
return nil
}
b.closed = true
b.stopOnce.Do(func() { close(b.stopc) })
b.pinAddr = ""
// In the case of following scenario:
// 1. upc is not closed; no pinned address
// 2. client issues an RPC, calling invoke(), which calls Get(), enters for loop, blocks
// 3. client.conn.Close() calls balancer.Close(); closed = true
// 4. for loop in Get() never exits since ctx is the context passed in by the client and may not be canceled
// we must close upc so Get() exits from blocking on upc
select {
case <-b.upc:
default:
// terminate all waiting Get()s
close(b.upc)
}
b.mu.Unlock()
b.wg.Wait()
// wait for updateNotifyLoop to finish
<-b.donec
close(b.notifyCh)
return nil
}
func grpcHealthCheck(client *Client, ep string) (bool, error) {
	conn, err := client.dial(ep)
	if err != nil {
@ -238,8 +583,7 @@ func grpcHealthCheck(client *Client, ep string) (bool, error) {
	cancel()
	if err != nil {
		if s, ok := status.FromError(err); ok && s.Code() == codes.Unavailable {
-			if s.Message() == unknownService {
-				// etcd < v3.3.0
+			if s.Message() == unknownService { // etcd < v3.3.0
				return true, nil
			}
		}
@ -247,3 +591,37 @@ func grpcHealthCheck(client *Client, ep string) (bool, error) {
	}
	return resp.Status == healthpb.HealthCheckResponse_SERVING, nil
}
func hasAddr(addrs []grpc.Address, targetAddr string) bool {
for _, addr := range addrs {
if targetAddr == addr.Addr {
return true
}
}
return false
}
func getHost(ep string) string {
url, uerr := url.Parse(ep)
if uerr != nil || !strings.Contains(ep, "://") {
return ep
}
return url.Host
}
func eps2addrs(eps []string) []grpc.Address {
addrs := make([]grpc.Address, len(eps))
for i := range eps {
addrs[i].Addr = getHost(eps[i])
}
return addrs
}
func getHostPort2ep(eps []string) map[string]string {
hm := make(map[string]string, len(eps))
for i := range eps {
_, host, _ := parseEndpoint(eps[i])
hm[host] = eps[i]
}
return hm
}
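grpcHealthCheck is the healthCheckFunc the balancer consults before re-pinning a gray-listed endpoint: it asks the standard grpc.health.v1.Health service and treats the "unknown service" response from pre-3.3.0 servers as healthy. A condensed, hypothetical version of the same probe against an already-dialed connection (the one-second timeout is illustrative):

package clientv3

import (
	"context"
	"time"

	"google.golang.org/grpc"
	healthpb "google.golang.org/grpc/health/grpc_health_v1"
)

// probeHealth is an illustrative sketch, not the function the client actually uses.
func probeHealth(conn *grpc.ClientConn) (bool, error) {
	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
	defer cancel()
	resp, err := healthpb.NewHealthClient(conn).Check(ctx, &healthpb.HealthCheckRequest{})
	if err != nil {
		return false, err
	}
	return resp.Status == healthpb.HealthCheckResponse_SERVING, nil
}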


@ -1,4 +1,4 @@
-// Copyright 2016 The etcd Authors
+// Copyright 2017 The etcd Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@ -28,29 +28,27 @@ import (
	"google.golang.org/grpc"
)

-var (
-	endpoints = []string{"localhost:2379", "localhost:22379", "localhost:32379"}
-)
+var endpoints = []string{"localhost:2379", "localhost:22379", "localhost:32379"}

func TestBalancerGetUnblocking(t *testing.T) {
-	sb := newSimpleBalancer(endpoints)
-	defer sb.Close()
-	if addrs := <-sb.Notify(); len(addrs) != len(endpoints) {
-		t.Errorf("Initialize newSimpleBalancer should have triggered Notify() chan, but it didn't")
+	hb := newHealthBalancer(endpoints, minHealthRetryDuration, func(string) (bool, error) { return true, nil })
+	defer hb.Close()
+	if addrs := <-hb.Notify(); len(addrs) != len(endpoints) {
+		t.Errorf("Initialize newHealthBalancer should have triggered Notify() chan, but it didn't")
	}
	unblockingOpts := grpc.BalancerGetOptions{BlockingWait: false}
-	_, _, err := sb.Get(context.Background(), unblockingOpts)
+	_, _, err := hb.Get(context.Background(), unblockingOpts)
	if err != ErrNoAddrAvilable {
		t.Errorf("Get() with no up endpoints should return ErrNoAddrAvailable, got: %v", err)
	}
-	down1 := sb.Up(grpc.Address{Addr: endpoints[1]})
-	if addrs := <-sb.Notify(); len(addrs) != 1 {
+	down1 := hb.Up(grpc.Address{Addr: endpoints[1]})
+	if addrs := <-hb.Notify(); len(addrs) != 1 {
		t.Errorf("first Up() should have triggered balancer to send the first connected address via Notify chan so that other connections can be closed")
	}
-	down2 := sb.Up(grpc.Address{Addr: endpoints[2]})
-	addrFirst, putFun, err := sb.Get(context.Background(), unblockingOpts)
+	down2 := hb.Up(grpc.Address{Addr: endpoints[2]})
+	addrFirst, putFun, err := hb.Get(context.Background(), unblockingOpts)
	if err != nil {
		t.Errorf("Get() with up endpoints should success, got %v", err)
	}
@ -60,32 +58,32 @@ func TestBalancerGetUnblocking(t *testing.T) {
	if putFun == nil {
		t.Errorf("Get() returned unexpected nil put function")
	}
-	addrSecond, _, _ := sb.Get(context.Background(), unblockingOpts)
+	addrSecond, _, _ := hb.Get(context.Background(), unblockingOpts)
	if addrFirst.Addr != addrSecond.Addr {
		t.Errorf("Get() didn't return the same address as previous call, got %v and %v", addrFirst, addrSecond)
	}
	down1(errors.New("error"))
-	if addrs := <-sb.Notify(); len(addrs) != len(endpoints) {
-		t.Errorf("closing the only connection should triggered balancer to send the all endpoints via Notify chan so that we can establish a connection")
+	if addrs := <-hb.Notify(); len(addrs) != len(endpoints)-1 { // we call down on one endpoint
+		t.Errorf("closing the only connection should triggered balancer to send the %d endpoints via Notify chan so that we can establish a connection", len(endpoints)-1)
	}
	down2(errors.New("error"))
-	_, _, err = sb.Get(context.Background(), unblockingOpts)
+	_, _, err = hb.Get(context.Background(), unblockingOpts)
	if err != ErrNoAddrAvilable {
		t.Errorf("Get() with no up endpoints should return ErrNoAddrAvailable, got: %v", err)
	}
}

func TestBalancerGetBlocking(t *testing.T) {
-	sb := newSimpleBalancer(endpoints)
-	defer sb.Close()
-	if addrs := <-sb.Notify(); len(addrs) != len(endpoints) {
-		t.Errorf("Initialize newSimpleBalancer should have triggered Notify() chan, but it didn't")
+	hb := newHealthBalancer(endpoints, minHealthRetryDuration, func(string) (bool, error) { return true, nil })
+	defer hb.Close()
+	if addrs := <-hb.Notify(); len(addrs) != len(endpoints) {
+		t.Errorf("Initialize newHealthBalancer should have triggered Notify() chan, but it didn't")
	}
	blockingOpts := grpc.BalancerGetOptions{BlockingWait: true}
	ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*100)
-	_, _, err := sb.Get(ctx, blockingOpts)
+	_, _, err := hb.Get(ctx, blockingOpts)
	cancel()
	if err != context.DeadlineExceeded {
		t.Errorf("Get() with no up endpoints should timeout, got %v", err)
@ -94,15 +92,15 @@ func TestBalancerGetBlocking(t *testing.T) {
	downC := make(chan func(error), 1)

	go func() {
-		// ensure sb.Up() will be called after sb.Get() to see if Up() releases blocking Get()
+		// ensure hb.Up() will be called after hb.Get() to see if Up() releases blocking Get()
		time.Sleep(time.Millisecond * 100)
-		f := sb.Up(grpc.Address{Addr: endpoints[1]})
-		if addrs := <-sb.Notify(); len(addrs) != 1 {
+		f := hb.Up(grpc.Address{Addr: endpoints[1]})
+		if addrs := <-hb.Notify(); len(addrs) != 1 {
			t.Errorf("first Up() should have triggered balancer to send the first connected address via Notify chan so that other connections can be closed")
		}
		downC <- f
	}()
-	addrFirst, putFun, err := sb.Get(context.Background(), blockingOpts)
+	addrFirst, putFun, err := hb.Get(context.Background(), blockingOpts)
	if err != nil {
		t.Errorf("Get() with up endpoints should success, got %v", err)
	}
@ -114,19 +112,19 @@ func TestBalancerGetBlocking(t *testing.T) {
	}
	down1 := <-downC
-	down2 := sb.Up(grpc.Address{Addr: endpoints[2]})
-	addrSecond, _, _ := sb.Get(context.Background(), blockingOpts)
+	down2 := hb.Up(grpc.Address{Addr: endpoints[2]})
+	addrSecond, _, _ := hb.Get(context.Background(), blockingOpts)
	if addrFirst.Addr != addrSecond.Addr {
		t.Errorf("Get() didn't return the same address as previous call, got %v and %v", addrFirst, addrSecond)
	}
	down1(errors.New("error"))
-	if addrs := <-sb.Notify(); len(addrs) != len(endpoints) {
-		t.Errorf("closing the only connection should triggered balancer to send the all endpoints via Notify chan so that we can establish a connection")
+	if addrs := <-hb.Notify(); len(addrs) != len(endpoints)-1 { // we call down on one endpoint
+		t.Errorf("closing the only connection should triggered balancer to send the %d endpoints via Notify chan so that we can establish a connection", len(endpoints)-1)
	}
	down2(errors.New("error"))
	ctx, cancel = context.WithTimeout(context.Background(), time.Millisecond*100)
-	_, _, err = sb.Get(ctx, blockingOpts)
+	_, _, err = hb.Get(ctx, blockingOpts)
	cancel()
	if err != context.DeadlineExceeded {
		t.Errorf("Get() with no up endpoints should timeout, got %v", err)
@ -168,9 +166,8 @@ func TestHealthBalancerGraylist(t *testing.T) {
		}()
	}

-	sb := newSimpleBalancer(eps)
	tf := func(s string) (bool, error) { return false, nil }
-	hb := newHealthBalancer(sb, 5*time.Second, tf)
+	hb := newHealthBalancer(eps, 5*time.Second, tf)

	conn, err := grpc.Dial("", grpc.WithInsecure(), grpc.WithBalancer(hb))
	testutil.AssertNil(t, err)
@ -203,13 +200,13 @@ func TestBalancerDoNotBlockOnClose(t *testing.T) {
	defer kcl.close()

	for i := 0; i < 5; i++ {
-		sb := newSimpleBalancer(kcl.endpoints())
-		conn, err := grpc.Dial("", grpc.WithInsecure(), grpc.WithBalancer(sb))
+		hb := newHealthBalancer(kcl.endpoints(), minHealthRetryDuration, func(string) (bool, error) { return true, nil })
+		conn, err := grpc.Dial("", grpc.WithInsecure(), grpc.WithBalancer(hb))
		if err != nil {
			t.Fatal(err)
		}
		kvc := pb.NewKVClient(conn)
-		<-sb.readyc
+		<-hb.readyc

		var wg sync.WaitGroup
		wg.Add(100)
@ -225,7 +222,7 @@ func TestBalancerDoNotBlockOnClose(t *testing.T) {
	bclosec, cclosec := make(chan struct{}), make(chan struct{})
	go func() {
		defer close(bclosec)
-		sb.Close()
+		hb.Close()
	}()
	go func() {
		defer close(cclosec)


@ -472,8 +472,9 @@ func TestKVNewAfterClose(t *testing.T) {
	donec := make(chan struct{})
	go func() {
-		if _, err := cli.Get(context.TODO(), "foo"); err != context.Canceled {
-			t.Fatalf("expected %v, got %v", context.Canceled, err)
+		_, err := cli.Get(context.TODO(), "foo")
+		if err != context.Canceled && err != grpc.ErrClientConnClosing {
+			t.Fatalf("expected %v or %v, got %v", context.Canceled, grpc.ErrClientConnClosing, err)
		}
		close(donec)
	}()


@ -45,7 +45,8 @@ func isNonRepeatableStopError(err error) bool {
	if ev.Code() != codes.Unavailable {
		return true
	}
-	return rpctypes.ErrorDesc(err) != "there is no address available"
+	desc := rpctypes.ErrorDesc(err)
+	return desc != "there is no address available" && desc != "there is no connection available"
}

func (c *Client) newRetryWrapper(isStop retryStopErrFunc) retryRPCFunc {
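The predicate now treats an Unavailable error as retryable only when its description is one of the balancer's "no address/connection available" messages; any other Unavailable description stops non-repeatable RPCs. A hypothetical classification sketch, assuming the elided start of isNonRepeatableStopError derives its status value from err in the usual way:

package clientv3

import (
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/status"
)

// classifyRetryErrors is an illustrative sketch of how the predicate behaves.
func classifyRetryErrors() (stopOnBalancerWait, stopOnServerError bool) {
	// The balancer's "nothing connected yet" error: not a stop error, so the
	// retry wrapper keeps waiting for a connection.
	stopOnBalancerWait = isNonRepeatableStopError(
		status.Error(codes.Unavailable, "there is no address available")) // false

	// Any other Unavailable description stops non-repeatable RPCs.
	stopOnServerError = isNonRepeatableStopError(
		status.Error(codes.Unavailable, "etcdserver: request timed out")) // true

	return stopOnBalancerWait, stopOnServerError
}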


@ -403,6 +403,6 @@ type pickFirst struct {
	*roundRobin
}
-func pickFirstBalancer(r naming.Resolver) Balancer {
+func pickFirstBalancerV1(r naming.Resolver) Balancer {
	return &pickFirst{&roundRobin{r: r}}
}

cmd/vendor/google.golang.org/grpc/balancer/balancer.go (new vendored file, 206 lines)

@ -0,0 +1,206 @@
/*
*
* Copyright 2017 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
// Package balancer defines APIs for load balancing in gRPC.
// All APIs in this package are experimental.
package balancer
import (
"errors"
"net"
"golang.org/x/net/context"
"google.golang.org/grpc/connectivity"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/resolver"
)
var (
// m is a map from name to balancer builder.
m = make(map[string]Builder)
// defaultBuilder is the default balancer to use.
defaultBuilder Builder // TODO(bar) install pickfirst as default.
)
// Register registers the balancer builder to the balancer map.
// b.Name will be used as the name registered with this builder.
func Register(b Builder) {
m[b.Name()] = b
}
// Get returns the balancer builder registered with the given name.
// If no builder is registered with the name, the default pickfirst will
// be used.
func Get(name string) Builder {
if b, ok := m[name]; ok {
return b
}
return defaultBuilder
}
// SubConn represents a gRPC sub connection.
// Each sub connection contains a list of addresses. gRPC will
// try to connect to them (in sequence), and stop trying the
// remainder once one connection is successful.
//
// The reconnect backoff will be applied on the list, not a single address.
// For example, try_on_all_addresses -> backoff -> try_on_all_addresses.
//
// All SubConns start in IDLE, and will not try to connect. To trigger
// the connecting, Balancers must call Connect.
// When the connection encounters an error, it will reconnect immediately.
// When the connection becomes IDLE, it will not reconnect unless Connect is
// called.
type SubConn interface {
// UpdateAddresses updates the addresses used in this SubConn.
// gRPC checks if currently-connected address is still in the new list.
// If it's in the list, the connection will be kept.
// If it's not in the list, the connection will be gracefully closed, and
// a new connection will be created.
//
// This will trigger a state transition for the SubConn.
UpdateAddresses([]resolver.Address)
// Connect starts the connecting for this SubConn.
Connect()
}
// NewSubConnOptions contains options to create new SubConn.
type NewSubConnOptions struct{}
// ClientConn represents a gRPC ClientConn.
type ClientConn interface {
// NewSubConn is called by balancer to create a new SubConn.
// It doesn't block and wait for the connections to be established.
// Behaviors of the SubConn can be controlled by options.
NewSubConn([]resolver.Address, NewSubConnOptions) (SubConn, error)
// RemoveSubConn removes the SubConn from ClientConn.
// The SubConn will be shutdown.
RemoveSubConn(SubConn)
// UpdateBalancerState is called by balancer to notify gRPC that some internal
// state in balancer has changed.
//
// gRPC will update the connectivity state of the ClientConn, and will call pick
// on the new picker to pick new SubConn.
UpdateBalancerState(s connectivity.State, p Picker)
// Target returns the dial target for this ClientConn.
Target() string
}
// BuildOptions contains additional information for Build.
type BuildOptions struct {
// DialCreds is the transport credential the Balancer implementation can
// use to dial to a remote load balancer server. The Balancer implementations
// can ignore this if it does not need to talk to another party securely.
DialCreds credentials.TransportCredentials
// Dialer is the custom dialer the Balancer implementation can use to dial
// to a remote load balancer server. The Balancer implementations
// can ignore this if it doesn't need to talk to remote balancer.
Dialer func(context.Context, string) (net.Conn, error)
}
// Builder creates a balancer.
type Builder interface {
// Build creates a new balancer with the ClientConn.
Build(cc ClientConn, opts BuildOptions) Balancer
// Name returns the name of balancers built by this builder.
// It will be used to pick balancers (for example in service config).
Name() string
}
// PickOptions contains addition information for the Pick operation.
type PickOptions struct{}
// DoneInfo contains additional information for done.
type DoneInfo struct {
// Err is the rpc error the RPC finished with. It could be nil.
Err error
}
var (
// ErrNoSubConnAvailable indicates no SubConn is available for pick().
// gRPC will block the RPC until a new picker is available via UpdateBalancerState().
ErrNoSubConnAvailable = errors.New("no SubConn is available")
// ErrTransientFailure indicates all SubConns are in TransientFailure.
// WaitForReady RPCs will block, non-WaitForReady RPCs will fail.
ErrTransientFailure = errors.New("all SubConns are in TransientFailure")
)
// Picker is used by gRPC to pick a SubConn to send an RPC.
// Balancer is expected to generate a new picker from its snapshot every time its
// internal state has changed.
//
// The pickers used by gRPC can be updated by ClientConn.UpdateBalancerState().
type Picker interface {
// Pick returns the SubConn to be used to send the RPC.
// The returned SubConn must be one returned by NewSubConn().
//
// This function is expected to return:
// - a SubConn that is known to be READY;
// - ErrNoSubConnAvailable if no SubConn is available, but progress is being
// made (for example, some SubConn is in CONNECTING mode);
// - other errors if no active connecting is happening (for example, all SubConn
// are in TRANSIENT_FAILURE mode).
//
// If a SubConn is returned:
// - If it is READY, gRPC will send the RPC on it;
// - If it is not ready, or becomes not ready after it's returned, gRPC will block
// this call until a new picker is updated and will call pick on the new picker.
//
// If the returned error is not nil:
// - If the error is ErrNoSubConnAvailable, gRPC will block until UpdateBalancerState()
// - If the error is ErrTransientFailure:
// - If the RPC is wait-for-ready, gRPC will block until UpdateBalancerState()
// is called to pick again;
// - Otherwise, RPC will fail with unavailable error.
// - Else (error is other non-nil error):
// - The RPC will fail with unavailable error.
//
// The returned done() function will be called once the rpc has finished, with the
// final status of that RPC.
// done may be nil if balancer doesn't care about the RPC status.
Pick(ctx context.Context, opts PickOptions) (conn SubConn, done func(DoneInfo), err error)
}
// Balancer takes input from gRPC, manages SubConns, and collects and aggregates
// the connectivity states.
//
// It also generates and updates the Picker used by gRPC to pick SubConns for RPCs.
//
// HandleSubConnectionStateChange, HandleResolvedAddrs and Close are guaranteed
// to be called synchronously from the same goroutine.
// There's no guarantee on picker.Pick, it may be called anytime.
type Balancer interface {
// HandleSubConnStateChange is called by gRPC when the connectivity state
// of sc has changed.
// Balancer is expected to aggregate all the state of SubConn and report
// that back to gRPC.
// Balancer should also generate and update Pickers when its internal state has
// been changed by the new state.
HandleSubConnStateChange(sc SubConn, state connectivity.State)
// HandleResolvedAddrs is called by gRPC to send updated resolved addresses to
// balancers.
// Balancer can create new SubConn or remove SubConn with the addresses.
// An empty address slice and a non-nil error will be passed if the resolver returns
// non-nil error to gRPC.
HandleResolvedAddrs([]resolver.Address, error)
// Close closes the balancer. The balancer is not required to call
// ClientConn.RemoveSubConn for its existing SubConns.
Close()
}
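This package replaces the v1 Balancer with a Builder/Balancer/Picker triple: a Builder is registered by name, the Balancer reacts to resolver and SubConn state updates, and every state change publishes a fresh Picker through UpdateBalancerState. The following compressed, hypothetical implementation shows how the pieces fit together; the package name and balancer name are invented, it pins a single SubConn over all resolved addresses, and it does no real state aggregation:

package firstpick

import (
	"golang.org/x/net/context"

	"google.golang.org/grpc/balancer"
	"google.golang.org/grpc/connectivity"
	"google.golang.org/grpc/resolver"
)

func init() { balancer.Register(&builder{}) }

type builder struct{}

func (*builder) Name() string { return "firstpick" }

func (*builder) Build(cc balancer.ClientConn, opts balancer.BuildOptions) balancer.Balancer {
	return &fpBalancer{cc: cc}
}

type fpBalancer struct {
	cc balancer.ClientConn
	sc balancer.SubConn
}

func (b *fpBalancer) HandleResolvedAddrs(addrs []resolver.Address, err error) {
	if err != nil || len(addrs) == 0 || b.sc != nil {
		return
	}
	// One SubConn over all resolved addresses; gRPC tries them in sequence.
	sc, err := b.cc.NewSubConn(addrs, balancer.NewSubConnOptions{})
	if err != nil {
		return
	}
	b.sc = sc
	sc.Connect() // SubConns start IDLE and only dial once Connect is called
}

func (b *fpBalancer) HandleSubConnStateChange(sc balancer.SubConn, s connectivity.State) {
	// Publish a new picker on every state change; gRPC blocks or fails RPCs
	// according to the picker's result until the SubConn is READY.
	b.cc.UpdateBalancerState(s, &picker{sc: sc, state: s})
}

func (b *fpBalancer) Close() {}

type picker struct {
	sc    balancer.SubConn
	state connectivity.State
}

func (p *picker) Pick(ctx context.Context, opts balancer.PickOptions) (balancer.SubConn, func(balancer.DoneInfo), error) {
	if p.state != connectivity.Ready {
		// Tell gRPC to block until the next UpdateBalancerState.
		return nil, nil, balancer.ErrNoSubConnAvailable
	}
	return p.sc, nil, nil
}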


@ -0,0 +1,248 @@
/*
*
* Copyright 2017 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package grpc
import (
"sync"
"google.golang.org/grpc/balancer"
"google.golang.org/grpc/connectivity"
"google.golang.org/grpc/grpclog"
"google.golang.org/grpc/resolver"
)
// scStateUpdate contains the subConn and the new state it changed to.
type scStateUpdate struct {
sc balancer.SubConn
state connectivity.State
}
// scStateUpdateBuffer is an unbounded channel for scStateChangeTuple.
// TODO make a general purpose buffer that uses interface{}.
type scStateUpdateBuffer struct {
c chan *scStateUpdate
mu sync.Mutex
backlog []*scStateUpdate
}
func newSCStateUpdateBuffer() *scStateUpdateBuffer {
return &scStateUpdateBuffer{
c: make(chan *scStateUpdate, 1),
}
}
func (b *scStateUpdateBuffer) put(t *scStateUpdate) {
b.mu.Lock()
defer b.mu.Unlock()
if len(b.backlog) == 0 {
select {
case b.c <- t:
return
default:
}
}
b.backlog = append(b.backlog, t)
}
func (b *scStateUpdateBuffer) load() {
b.mu.Lock()
defer b.mu.Unlock()
if len(b.backlog) > 0 {
select {
case b.c <- b.backlog[0]:
b.backlog[0] = nil
b.backlog = b.backlog[1:]
default:
}
}
}
// get returns the channel that receives a recvMsg in the buffer.
//
// Upon receiving, the caller should call load to send another
// scStateChangeTuple onto the channel if there is any.
func (b *scStateUpdateBuffer) get() <-chan *scStateUpdate {
return b.c
}
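The buffer gives the watcher goroutine an effectively unbounded queue of SubConn state changes: put never blocks, get exposes a one-element channel, and load must be called after every receive to re-arm the channel from the backlog. A small hypothetical consumer showing that receive-then-load pattern:

package grpc

// drainStateUpdates is an illustrative sketch mirroring how
// ccBalancerWrapper.watcher consumes an scStateUpdateBuffer.
func drainStateUpdates(buf *scStateUpdateBuffer, stop <-chan struct{}) {
	for {
		select {
		case t := <-buf.get():
			buf.load()           // pull the next backlog entry into the channel, if any
			_, _ = t.sc, t.state // hand the SubConn and its new state to the balancer here
		case <-stop:
			return
		}
	}
}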
// resolverUpdate contains the new resolved addresses or error if there's
// any.
type resolverUpdate struct {
addrs []resolver.Address
err error
}
// ccBalancerWrapper is a wrapper on top of cc for balancers.
// It implements balancer.ClientConn interface.
type ccBalancerWrapper struct {
cc *ClientConn
balancer balancer.Balancer
stateChangeQueue *scStateUpdateBuffer
resolverUpdateCh chan *resolverUpdate
done chan struct{}
}
func newCCBalancerWrapper(cc *ClientConn, b balancer.Builder, bopts balancer.BuildOptions) *ccBalancerWrapper {
ccb := &ccBalancerWrapper{
cc: cc,
stateChangeQueue: newSCStateUpdateBuffer(),
resolverUpdateCh: make(chan *resolverUpdate, 1),
done: make(chan struct{}),
}
go ccb.watcher()
ccb.balancer = b.Build(ccb, bopts)
return ccb
}
// watcher runs balancer functions sequentially, so the balancer can be
// implemented lock-free.
func (ccb *ccBalancerWrapper) watcher() {
for {
select {
case t := <-ccb.stateChangeQueue.get():
ccb.stateChangeQueue.load()
ccb.balancer.HandleSubConnStateChange(t.sc, t.state)
case t := <-ccb.resolverUpdateCh:
ccb.balancer.HandleResolvedAddrs(t.addrs, t.err)
case <-ccb.done:
}
select {
case <-ccb.done:
ccb.balancer.Close()
return
default:
}
}
}
func (ccb *ccBalancerWrapper) close() {
close(ccb.done)
}
func (ccb *ccBalancerWrapper) handleSubConnStateChange(sc balancer.SubConn, s connectivity.State) {
// When updating addresses for a SubConn, if the address in use is not in
// the new addresses, the old ac will be tearDown() and a new ac will be
// created. tearDown() generates a state change with Shutdown state, we
// don't want the balancer to receive this state change. So before
// tearDown() on the old ac, ac.acbw (acWrapper) will be set to nil, and
// this function will be called with (nil, Shutdown). We don't need to call
// balancer method in this case.
if sc == nil {
return
}
ccb.stateChangeQueue.put(&scStateUpdate{
sc: sc,
state: s,
})
}
func (ccb *ccBalancerWrapper) handleResolvedAddrs(addrs []resolver.Address, err error) {
select {
case <-ccb.resolverUpdateCh:
default:
}
ccb.resolverUpdateCh <- &resolverUpdate{
addrs: addrs,
err: err,
}
}
func (ccb *ccBalancerWrapper) NewSubConn(addrs []resolver.Address, opts balancer.NewSubConnOptions) (balancer.SubConn, error) {
grpclog.Infof("ccBalancerWrapper: new subconn: %v", addrs)
ac, err := ccb.cc.newAddrConn(addrs)
if err != nil {
return nil, err
}
acbw := &acBalancerWrapper{ac: ac}
ac.acbw = acbw
return acbw, nil
}
func (ccb *ccBalancerWrapper) RemoveSubConn(sc balancer.SubConn) {
grpclog.Infof("ccBalancerWrapper: removing subconn")
acbw, ok := sc.(*acBalancerWrapper)
if !ok {
return
}
ccb.cc.removeAddrConn(acbw.getAddrConn(), errConnDrain)
}
func (ccb *ccBalancerWrapper) UpdateBalancerState(s connectivity.State, p balancer.Picker) {
grpclog.Infof("ccBalancerWrapper: updating state and picker called by balancer: %v, %p", s, p)
ccb.cc.csMgr.updateState(s)
ccb.cc.blockingpicker.updatePicker(p)
}
func (ccb *ccBalancerWrapper) Target() string {
return ccb.cc.target
}
// acBalancerWrapper is a wrapper on top of ac for balancers.
// It implements balancer.SubConn interface.
type acBalancerWrapper struct {
mu sync.Mutex
ac *addrConn
}
func (acbw *acBalancerWrapper) UpdateAddresses(addrs []resolver.Address) {
grpclog.Infof("acBalancerWrapper: UpdateAddresses called with %v", addrs)
acbw.mu.Lock()
defer acbw.mu.Unlock()
if !acbw.ac.tryUpdateAddrs(addrs) {
cc := acbw.ac.cc
acbw.ac.mu.Lock()
// Set old ac.acbw to nil so the Shutdown state update will be ignored
// by balancer.
//
// TODO(bar) the state transition could be wrong when tearDown() old ac
// and creating new ac, fix the transition.
acbw.ac.acbw = nil
acbw.ac.mu.Unlock()
acState := acbw.ac.getState()
acbw.ac.tearDown(errConnDrain)
if acState == connectivity.Shutdown {
return
}
ac, err := cc.newAddrConn(addrs)
if err != nil {
grpclog.Warningf("acBalancerWrapper: UpdateAddresses: failed to newAddrConn: %v", err)
return
}
acbw.ac = ac
ac.acbw = acbw
if acState != connectivity.Idle {
ac.connect(false)
}
}
}
func (acbw *acBalancerWrapper) Connect() {
acbw.mu.Lock()
defer acbw.mu.Unlock()
acbw.ac.connect(false)
}
func (acbw *acBalancerWrapper) getAddrConn() *addrConn {
acbw.mu.Lock()
defer acbw.mu.Unlock()
return acbw.ac
}


@ -0,0 +1,367 @@
/*
*
* Copyright 2017 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package grpc
import (
"sync"
"golang.org/x/net/context"
"google.golang.org/grpc/balancer"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/connectivity"
"google.golang.org/grpc/grpclog"
"google.golang.org/grpc/resolver"
)
type balancerWrapperBuilder struct {
b Balancer // The v1 balancer.
}
func (bwb *balancerWrapperBuilder) Build(cc balancer.ClientConn, opts balancer.BuildOptions) balancer.Balancer {
bwb.b.Start(cc.Target(), BalancerConfig{
DialCreds: opts.DialCreds,
Dialer: opts.Dialer,
})
_, pickfirst := bwb.b.(*pickFirst)
bw := &balancerWrapper{
balancer: bwb.b,
pickfirst: pickfirst,
cc: cc,
startCh: make(chan struct{}),
conns: make(map[resolver.Address]balancer.SubConn),
connSt: make(map[balancer.SubConn]*scState),
csEvltr: &connectivityStateEvaluator{},
state: connectivity.Idle,
}
cc.UpdateBalancerState(connectivity.Idle, bw)
go bw.lbWatcher()
return bw
}
func (bwb *balancerWrapperBuilder) Name() string {
return "wrapper"
}
type scState struct {
addr Address // The v1 address type.
s connectivity.State
down func(error)
}
type balancerWrapper struct {
balancer Balancer // The v1 balancer.
pickfirst bool
cc balancer.ClientConn
// To aggregate the connectivity state.
csEvltr *connectivityStateEvaluator
state connectivity.State
mu sync.Mutex
conns map[resolver.Address]balancer.SubConn
connSt map[balancer.SubConn]*scState
// This channel is closed when handling the first resolver result.
// lbWatcher blocks until this is closed, to avoid a race between:
// - NewSubConn is created and cc wants to notify the balancer of state changes;
// - Build hasn't returned yet, so cc doesn't have access to the balancer.
startCh chan struct{}
}
// lbWatcher watches the Notify channel of the balancer and manages
// connections accordingly.
func (bw *balancerWrapper) lbWatcher() {
<-bw.startCh
grpclog.Infof("balancerWrapper: is pickfirst: %v\n", bw.pickfirst)
notifyCh := bw.balancer.Notify()
if notifyCh == nil {
// There's no resolver in the balancer. Connect directly.
a := resolver.Address{
Addr: bw.cc.Target(),
Type: resolver.Backend,
}
sc, err := bw.cc.NewSubConn([]resolver.Address{a}, balancer.NewSubConnOptions{})
if err != nil {
grpclog.Warningf("Error creating connection to %v. Err: %v", a, err)
} else {
bw.mu.Lock()
bw.conns[a] = sc
bw.connSt[sc] = &scState{
addr: Address{Addr: bw.cc.Target()},
s: connectivity.Idle,
}
bw.mu.Unlock()
sc.Connect()
}
return
}
for addrs := range notifyCh {
grpclog.Infof("balancerWrapper: got update addr from Notify: %v\n", addrs)
if bw.pickfirst {
var (
oldA resolver.Address
oldSC balancer.SubConn
)
bw.mu.Lock()
for oldA, oldSC = range bw.conns {
break
}
bw.mu.Unlock()
if len(addrs) <= 0 {
if oldSC != nil {
// Teardown old sc.
bw.mu.Lock()
delete(bw.conns, oldA)
delete(bw.connSt, oldSC)
bw.mu.Unlock()
bw.cc.RemoveSubConn(oldSC)
}
continue
}
var newAddrs []resolver.Address
for _, a := range addrs {
newAddr := resolver.Address{
Addr: a.Addr,
Type: resolver.Backend, // All addresses from the balancer are backends.
ServerName: "",
Metadata: a.Metadata,
}
newAddrs = append(newAddrs, newAddr)
}
if oldSC == nil {
// Create new sc.
sc, err := bw.cc.NewSubConn(newAddrs, balancer.NewSubConnOptions{})
if err != nil {
grpclog.Warningf("Error creating connection to %v. Err: %v", newAddrs, err)
} else {
bw.mu.Lock()
// For pickfirst, there should be only one SubConn, so the
// address doesn't matter. All state updates (up and down)
// and all picks happen on that single SubConn.
bw.conns[resolver.Address{}] = sc
bw.connSt[sc] = &scState{
addr: addrs[0], // Use the first address.
s: connectivity.Idle,
}
bw.mu.Unlock()
sc.Connect()
}
} else {
oldSC.UpdateAddresses(newAddrs)
bw.mu.Lock()
bw.connSt[oldSC].addr = addrs[0]
bw.mu.Unlock()
}
} else {
var (
add []resolver.Address // Addresses that need new connections.
del []balancer.SubConn // Connections that need to be torn down.
)
resAddrs := make(map[resolver.Address]Address)
for _, a := range addrs {
resAddrs[resolver.Address{
Addr: a.Addr,
Type: resolver.Backend, // All addresses from the balancer are backends.
ServerName: "",
Metadata: a.Metadata,
}] = a
}
bw.mu.Lock()
for a := range resAddrs {
if _, ok := bw.conns[a]; !ok {
add = append(add, a)
}
}
for a, c := range bw.conns {
if _, ok := resAddrs[a]; !ok {
del = append(del, c)
delete(bw.conns, a)
// Keep the state of this sc in bw.connSt until its state becomes Shutdown.
}
}
bw.mu.Unlock()
for _, a := range add {
sc, err := bw.cc.NewSubConn([]resolver.Address{a}, balancer.NewSubConnOptions{})
if err != nil {
grpclog.Warningf("Error creating connection to %v. Err: %v", a, err)
} else {
bw.mu.Lock()
bw.conns[a] = sc
bw.connSt[sc] = &scState{
addr: resAddrs[a],
s: connectivity.Idle,
}
bw.mu.Unlock()
sc.Connect()
}
}
for _, c := range del {
bw.cc.RemoveSubConn(c)
}
}
}
}
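// Illustrative walk-through (added for clarity, not in the original file): in the
// non-pickfirst branch above, if bw.conns currently holds connections for {A, B}
// and Notify delivers [B, C], the loop computes add = [C] (a new SubConn is
// created and Connect()ed) and del = [SubConn(A)] (that SubConn is removed via
// RemoveSubConn); B is left untouched.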
func (bw *balancerWrapper) HandleSubConnStateChange(sc balancer.SubConn, s connectivity.State) {
grpclog.Infof("balancerWrapper: handle subconn state change: %p, %v", sc, s)
bw.mu.Lock()
defer bw.mu.Unlock()
scSt, ok := bw.connSt[sc]
if !ok {
return
}
if s == connectivity.Idle {
sc.Connect()
}
oldS := scSt.s
scSt.s = s
if oldS != connectivity.Ready && s == connectivity.Ready {
scSt.down = bw.balancer.Up(scSt.addr)
} else if oldS == connectivity.Ready && s != connectivity.Ready {
if scSt.down != nil {
scSt.down(errConnClosing)
}
}
sa := bw.csEvltr.recordTransition(oldS, s)
if bw.state != sa {
bw.state = sa
}
bw.cc.UpdateBalancerState(bw.state, bw)
if s == connectivity.Shutdown {
// Remove state for this sc.
delete(bw.connSt, sc)
}
return
}
func (bw *balancerWrapper) HandleResolvedAddrs([]resolver.Address, error) {
bw.mu.Lock()
defer bw.mu.Unlock()
select {
case <-bw.startCh:
default:
close(bw.startCh)
}
// There should be a resolver inside the balancer.
// All updates here, if any, are ignored.
return
}
func (bw *balancerWrapper) Close() {
bw.mu.Lock()
defer bw.mu.Unlock()
select {
case <-bw.startCh:
default:
close(bw.startCh)
}
bw.balancer.Close()
return
}
// The picker is the balancerWrapper itself.
// Pick should never return ErrNoSubConnAvailable.
// It either blocks or returns an error, consistent with the v1 balancer's Get().
func (bw *balancerWrapper) Pick(ctx context.Context, opts balancer.PickOptions) (balancer.SubConn, func(balancer.DoneInfo), error) {
failfast := true // Default failfast is true.
if ss, ok := rpcInfoFromContext(ctx); ok {
failfast = ss.failfast
}
a, p, err := bw.balancer.Get(ctx, BalancerGetOptions{BlockingWait: !failfast})
if err != nil {
return nil, nil, err
}
var done func(balancer.DoneInfo)
if p != nil {
done = func(i balancer.DoneInfo) { p() }
}
var sc balancer.SubConn
bw.mu.Lock()
defer bw.mu.Unlock()
if bw.pickfirst {
// Get the first sc in conns.
for _, sc = range bw.conns {
break
}
} else {
var ok bool
sc, ok = bw.conns[resolver.Address{
Addr: a.Addr,
Type: resolver.Backend,
ServerName: "",
Metadata: a.Metadata,
}]
if !ok && failfast {
return nil, nil, Errorf(codes.Unavailable, "there is no connection available")
}
if s, ok := bw.connSt[sc]; failfast && (!ok || s.s != connectivity.Ready) {
// If the returned sc is not ready and RPC is failfast,
// return error, and this RPC will fail.
return nil, nil, Errorf(codes.Unavailable, "there is no connection available")
}
}
return sc, done, nil
}
// connectivityStateEvaluator gets updated by addrConns when their
// states transition, based on which it evaluates the state of
// ClientConn.
type connectivityStateEvaluator struct {
mu sync.Mutex
numReady uint64 // Number of addrConns in ready state.
numConnecting uint64 // Number of addrConns in connecting state.
numTransientFailure uint64 // Number of addrConns in transientFailure.
}
// recordTransition records a state change in a subConn and, based on the updated
// counters, evaluates what the aggregated state should be.
// The aggregated state can only move between Ready, Connecting and TransientFailure.
// The other states, Idle and Shutdown, are entered by the ClientConn itself: before
// any subConn is created the ClientConn is Idle, and when the ClientConn closes it
// moves to Shutdown.
// TODO Note that in later releases, a ClientConn with no activity will be put into an Idle state.
func (cse *connectivityStateEvaluator) recordTransition(oldState, newState connectivity.State) connectivity.State {
cse.mu.Lock()
defer cse.mu.Unlock()
// Update counters.
for idx, state := range []connectivity.State{oldState, newState} {
updateVal := 2*uint64(idx) - 1 // -1 for oldState and +1 for new.
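// Note: when idx == 0 this underflows to math.MaxUint64; adding that to an
// unsigned counter is equivalent to subtracting 1 (mod 2^64), so the counter
// for oldState is decremented and the counter for newState is incremented.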
switch state {
case connectivity.Ready:
cse.numReady += updateVal
case connectivity.Connecting:
cse.numConnecting += updateVal
case connectivity.TransientFailure:
cse.numTransientFailure += updateVal
}
}
// Evaluate.
if cse.numReady > 0 {
return connectivity.Ready
}
if cse.numConnecting > 0 {
return connectivity.Connecting
}
return connectivity.TransientFailure
}
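As an aside, here is a minimal standalone sketch of the same precedence rule (Ready beats Connecting beats TransientFailure), using only the public connectivity package; the aggregate helper below is hypothetical and not part of this change:

package main

import (
	"fmt"

	"google.golang.org/grpc/connectivity"
)

// aggregate mirrors the rule recordTransition implements: any Ready subConn
// makes the ClientConn Ready, otherwise any Connecting subConn makes it
// Connecting, otherwise it is TransientFailure.
func aggregate(states []connectivity.State) connectivity.State {
	var ready, connecting int
	for _, s := range states {
		switch s {
		case connectivity.Ready:
			ready++
		case connectivity.Connecting:
			connecting++
		}
	}
	if ready > 0 {
		return connectivity.Ready
	}
	if connecting > 0 {
		return connectivity.Connecting
	}
	return connectivity.TransientFailure
}

func main() {
	fmt.Println(aggregate([]connectivity.State{connectivity.Connecting, connectivity.TransientFailure})) // CONNECTING
	fmt.Println(aggregate([]connectivity.State{connectivity.Ready, connectivity.TransientFailure}))      // READY
}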


@ -25,6 +25,7 @@ import (
"golang.org/x/net/context" "golang.org/x/net/context"
"golang.org/x/net/trace" "golang.org/x/net/trace"
"google.golang.org/grpc/balancer"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
"google.golang.org/grpc/peer" "google.golang.org/grpc/peer"
"google.golang.org/grpc/stats" "google.golang.org/grpc/stats"
@ -135,7 +136,7 @@ func Invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli
} }
func invoke(ctx context.Context, method string, args, reply interface{}, cc *ClientConn, opts ...CallOption) (e error) { func invoke(ctx context.Context, method string, args, reply interface{}, cc *ClientConn, opts ...CallOption) (e error) {
c := defaultCallInfo c := defaultCallInfo()
mc := cc.GetMethodConfig(method) mc := cc.GetMethodConfig(method)
if mc.WaitForReady != nil { if mc.WaitForReady != nil {
c.failFast = !*mc.WaitForReady c.failFast = !*mc.WaitForReady
@ -149,13 +150,13 @@ func invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli
opts = append(cc.dopts.callOptions, opts...) opts = append(cc.dopts.callOptions, opts...)
for _, o := range opts { for _, o := range opts {
if err := o.before(&c); err != nil { if err := o.before(c); err != nil {
return toRPCErr(err) return toRPCErr(err)
} }
} }
defer func() { defer func() {
for _, o := range opts { for _, o := range opts {
o.after(&c) o.after(c)
} }
}() }()
@ -178,7 +179,7 @@ func invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli
} }
}() }()
} }
ctx = newContextWithRPCInfo(ctx) ctx = newContextWithRPCInfo(ctx, c.failFast)
sh := cc.dopts.copts.StatsHandler sh := cc.dopts.copts.StatsHandler
if sh != nil { if sh != nil {
ctx = sh.TagRPC(ctx, &stats.RPCTagInfo{FullMethodName: method, FailFast: c.failFast}) ctx = sh.TagRPC(ctx, &stats.RPCTagInfo{FullMethodName: method, FailFast: c.failFast})
@ -206,9 +207,9 @@ func invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli
err error err error
t transport.ClientTransport t transport.ClientTransport
stream *transport.Stream stream *transport.Stream
// Record the put handler from Balancer.Get(...). It is called once the // Record the done handler from Balancer.Get(...). It is called once the
// RPC has completed or failed. // RPC has completed or failed.
put func() done func(balancer.DoneInfo)
) )
// TODO(zhaoq): Need a formal spec of fail-fast. // TODO(zhaoq): Need a formal spec of fail-fast.
callHdr := &transport.CallHdr{ callHdr := &transport.CallHdr{
@ -222,10 +223,7 @@ func invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli
callHdr.Creds = c.creds callHdr.Creds = c.creds
} }
gopts := BalancerGetOptions{ t, done, err = cc.getTransport(ctx, c.failFast)
BlockingWait: !c.failFast,
}
t, put, err = cc.getTransport(ctx, gopts)
if err != nil { if err != nil {
// TODO(zhaoq): Probably revisit the error handling. // TODO(zhaoq): Probably revisit the error handling.
if _, ok := status.FromError(err); ok { if _, ok := status.FromError(err); ok {
@ -245,14 +243,14 @@ func invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli
} }
stream, err = t.NewStream(ctx, callHdr) stream, err = t.NewStream(ctx, callHdr)
if err != nil { if err != nil {
if put != nil { if done != nil {
if _, ok := err.(transport.ConnectionError); ok { if _, ok := err.(transport.ConnectionError); ok {
// If error is connection error, transport was sending data on wire, // If error is connection error, transport was sending data on wire,
// and we are not sure if anything has been sent on wire. // and we are not sure if anything has been sent on wire.
// If error is not connection error, we are sure nothing has been sent. // If error is not connection error, we are sure nothing has been sent.
updateRPCInfoInContext(ctx, rpcInfo{bytesSent: true, bytesReceived: false}) updateRPCInfoInContext(ctx, rpcInfo{bytesSent: true, bytesReceived: false})
} }
put() done(balancer.DoneInfo{Err: err})
} }
if _, ok := err.(transport.ConnectionError); (ok || err == transport.ErrStreamDrain) && !c.failFast { if _, ok := err.(transport.ConnectionError); (ok || err == transport.ErrStreamDrain) && !c.failFast {
continue continue
@ -262,14 +260,14 @@ func invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli
if peer, ok := peer.FromContext(stream.Context()); ok { if peer, ok := peer.FromContext(stream.Context()); ok {
c.peer = peer c.peer = peer
} }
err = sendRequest(ctx, cc.dopts, cc.dopts.cp, &c, callHdr, stream, t, args, topts) err = sendRequest(ctx, cc.dopts, cc.dopts.cp, c, callHdr, stream, t, args, topts)
if err != nil { if err != nil {
if put != nil { if done != nil {
updateRPCInfoInContext(ctx, rpcInfo{ updateRPCInfoInContext(ctx, rpcInfo{
bytesSent: stream.BytesSent(), bytesSent: stream.BytesSent(),
bytesReceived: stream.BytesReceived(), bytesReceived: stream.BytesReceived(),
}) })
put() done(balancer.DoneInfo{Err: err})
} }
// Retry a non-failfast RPC when // Retry a non-failfast RPC when
// i) there is a connection error; or // i) there is a connection error; or
@ -279,14 +277,14 @@ func invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli
} }
return toRPCErr(err) return toRPCErr(err)
} }
err = recvResponse(ctx, cc.dopts, t, &c, stream, reply) err = recvResponse(ctx, cc.dopts, t, c, stream, reply)
if err != nil { if err != nil {
if put != nil { if done != nil {
updateRPCInfoInContext(ctx, rpcInfo{ updateRPCInfoInContext(ctx, rpcInfo{
bytesSent: stream.BytesSent(), bytesSent: stream.BytesSent(),
bytesReceived: stream.BytesReceived(), bytesReceived: stream.BytesReceived(),
}) })
put() done(balancer.DoneInfo{Err: err})
} }
if _, ok := err.(transport.ConnectionError); (ok || err == transport.ErrStreamDrain) && !c.failFast { if _, ok := err.(transport.ConnectionError); (ok || err == transport.ErrStreamDrain) && !c.failFast {
continue continue
@ -297,12 +295,12 @@ func invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli
c.traceInfo.tr.LazyLog(&payload{sent: false, msg: reply}, true) c.traceInfo.tr.LazyLog(&payload{sent: false, msg: reply}, true)
} }
t.CloseStream(stream, nil) t.CloseStream(stream, nil)
if put != nil { if done != nil {
updateRPCInfoInContext(ctx, rpcInfo{ updateRPCInfoInContext(ctx, rpcInfo{
bytesSent: stream.BytesSent(), bytesSent: stream.BytesSent(),
bytesReceived: stream.BytesReceived(), bytesReceived: stream.BytesReceived(),
}) })
put() done(balancer.DoneInfo{Err: err})
} }
return stream.Status().Err() return stream.Status().Err()
} }

File diff suppressed because it is too large


@ -1,98 +0,0 @@
// +build go1.6,!go1.7
/*
*
* Copyright 2016 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package grpc
import (
"fmt"
"io"
"net"
"net/http"
"os"
"golang.org/x/net/context"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/grpc/transport"
)
// dialContext connects to the address on the named network.
func dialContext(ctx context.Context, network, address string) (net.Conn, error) {
return (&net.Dialer{Cancel: ctx.Done()}).Dial(network, address)
}
func sendHTTPRequest(ctx context.Context, req *http.Request, conn net.Conn) error {
req.Cancel = ctx.Done()
if err := req.Write(conn); err != nil {
return fmt.Errorf("failed to write the HTTP request: %v", err)
}
return nil
}
// toRPCErr converts an error into an error from the status package.
func toRPCErr(err error) error {
if _, ok := status.FromError(err); ok {
return err
}
switch e := err.(type) {
case transport.StreamError:
return status.Error(e.Code, e.Desc)
case transport.ConnectionError:
return status.Error(codes.Unavailable, e.Desc)
default:
switch err {
case context.DeadlineExceeded:
return status.Error(codes.DeadlineExceeded, err.Error())
case context.Canceled:
return status.Error(codes.Canceled, err.Error())
case ErrClientConnClosing:
return status.Error(codes.FailedPrecondition, err.Error())
}
}
return status.Error(codes.Unknown, err.Error())
}
// convertCode converts a standard Go error into its canonical code. Note that
// this is only used to translate the error returned by the server applications.
func convertCode(err error) codes.Code {
switch err {
case nil:
return codes.OK
case io.EOF:
return codes.OutOfRange
case io.ErrClosedPipe, io.ErrNoProgress, io.ErrShortBuffer, io.ErrShortWrite, io.ErrUnexpectedEOF:
return codes.FailedPrecondition
case os.ErrInvalid:
return codes.InvalidArgument
case context.Canceled:
return codes.Canceled
case context.DeadlineExceeded:
return codes.DeadlineExceeded
}
switch {
case os.IsExist(err):
return codes.AlreadyExists
case os.IsNotExist(err):
return codes.NotFound
case os.IsPermission(err):
return codes.PermissionDenied
}
return codes.Unknown
}


@ -1,98 +0,0 @@
// +build go1.7
/*
*
* Copyright 2016 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package grpc
import (
"context"
"io"
"net"
"net/http"
"os"
netctx "golang.org/x/net/context"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/grpc/transport"
)
// dialContext connects to the address on the named network.
func dialContext(ctx context.Context, network, address string) (net.Conn, error) {
return (&net.Dialer{}).DialContext(ctx, network, address)
}
func sendHTTPRequest(ctx context.Context, req *http.Request, conn net.Conn) error {
req = req.WithContext(ctx)
if err := req.Write(conn); err != nil {
return err
}
return nil
}
// toRPCErr converts an error into an error from the status package.
func toRPCErr(err error) error {
if _, ok := status.FromError(err); ok {
return err
}
switch e := err.(type) {
case transport.StreamError:
return status.Error(e.Code, e.Desc)
case transport.ConnectionError:
return status.Error(codes.Unavailable, e.Desc)
default:
switch err {
case context.DeadlineExceeded, netctx.DeadlineExceeded:
return status.Error(codes.DeadlineExceeded, err.Error())
case context.Canceled, netctx.Canceled:
return status.Error(codes.Canceled, err.Error())
case ErrClientConnClosing:
return status.Error(codes.FailedPrecondition, err.Error())
}
}
return status.Error(codes.Unknown, err.Error())
}
// convertCode converts a standard Go error into its canonical code. Note that
// this is only used to translate the error returned by the server applications.
func convertCode(err error) codes.Code {
switch err {
case nil:
return codes.OK
case io.EOF:
return codes.OutOfRange
case io.ErrClosedPipe, io.ErrNoProgress, io.ErrShortBuffer, io.ErrShortWrite, io.ErrUnexpectedEOF:
return codes.FailedPrecondition
case os.ErrInvalid:
return codes.InvalidArgument
case context.Canceled, netctx.Canceled:
return codes.Canceled
case context.DeadlineExceeded, netctx.DeadlineExceeded:
return codes.DeadlineExceeded
}
switch {
case os.IsExist(err):
return codes.AlreadyExists
case os.IsNotExist(err):
return codes.NotFound
case os.IsPermission(err):
return codes.PermissionDenied
}
return codes.Unknown
}


@ -73,7 +73,7 @@ func (x *balanceLoadClientStream) Recv() (*lbmpb.LoadBalanceResponse, error) {
// NewGRPCLBBalancer creates a grpclb load balancer. // NewGRPCLBBalancer creates a grpclb load balancer.
func NewGRPCLBBalancer(r naming.Resolver) Balancer { func NewGRPCLBBalancer(r naming.Resolver) Balancer {
return &balancer{ return &grpclbBalancer{
r: r, r: r,
} }
} }
@ -96,7 +96,7 @@ type grpclbAddrInfo struct {
dropForLoadBalancing bool dropForLoadBalancing bool
} }
type balancer struct { type grpclbBalancer struct {
r naming.Resolver r naming.Resolver
target string target string
mu sync.Mutex mu sync.Mutex
@ -113,7 +113,7 @@ type balancer struct {
clientStats lbmpb.ClientStats clientStats lbmpb.ClientStats
} }
func (b *balancer) watchAddrUpdates(w naming.Watcher, ch chan []remoteBalancerInfo) error { func (b *grpclbBalancer) watchAddrUpdates(w naming.Watcher, ch chan []remoteBalancerInfo) error {
updates, err := w.Next() updates, err := w.Next()
if err != nil { if err != nil {
grpclog.Warningf("grpclb: failed to get next addr update from watcher: %v", err) grpclog.Warningf("grpclb: failed to get next addr update from watcher: %v", err)
@ -187,7 +187,7 @@ func convertDuration(d *lbmpb.Duration) time.Duration {
return time.Duration(d.Seconds)*time.Second + time.Duration(d.Nanos)*time.Nanosecond return time.Duration(d.Seconds)*time.Second + time.Duration(d.Nanos)*time.Nanosecond
} }
func (b *balancer) processServerList(l *lbmpb.ServerList, seq int) { func (b *grpclbBalancer) processServerList(l *lbmpb.ServerList, seq int) {
if l == nil { if l == nil {
return return
} }
@ -230,7 +230,7 @@ func (b *balancer) processServerList(l *lbmpb.ServerList, seq int) {
return return
} }
func (b *balancer) sendLoadReport(s *balanceLoadClientStream, interval time.Duration, done <-chan struct{}) { func (b *grpclbBalancer) sendLoadReport(s *balanceLoadClientStream, interval time.Duration, done <-chan struct{}) {
ticker := time.NewTicker(interval) ticker := time.NewTicker(interval)
defer ticker.Stop() defer ticker.Stop()
for { for {
@ -259,7 +259,7 @@ func (b *balancer) sendLoadReport(s *balanceLoadClientStream, interval time.Dura
} }
} }
func (b *balancer) callRemoteBalancer(lbc *loadBalancerClient, seq int) (retry bool) { func (b *grpclbBalancer) callRemoteBalancer(lbc *loadBalancerClient, seq int) (retry bool) {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
stream, err := lbc.BalanceLoad(ctx) stream, err := lbc.BalanceLoad(ctx)
@ -332,7 +332,7 @@ func (b *balancer) callRemoteBalancer(lbc *loadBalancerClient, seq int) (retry b
return true return true
} }
func (b *balancer) Start(target string, config BalancerConfig) error { func (b *grpclbBalancer) Start(target string, config BalancerConfig) error {
b.rand = rand.New(rand.NewSource(time.Now().Unix())) b.rand = rand.New(rand.NewSource(time.Now().Unix()))
// TODO: Fall back to the basic direct connection if there is no name resolver. // TODO: Fall back to the basic direct connection if there is no name resolver.
if b.r == nil { if b.r == nil {
@ -461,8 +461,11 @@ func (b *balancer) Start(target string, config BalancerConfig) error {
// WithDialer takes a different type of function, so we instead use a special DialOption here. // WithDialer takes a different type of function, so we instead use a special DialOption here.
dopts = append(dopts, func(o *dialOptions) { o.copts.Dialer = dialer }) dopts = append(dopts, func(o *dialOptions) { o.copts.Dialer = dialer })
} }
dopts = append(dopts, WithBlock())
ccError = make(chan struct{}) ccError = make(chan struct{})
cc, err = Dial(rb.addr, dopts...) ctx, cancel := context.WithTimeout(context.Background(), time.Second)
cc, err = DialContext(ctx, rb.addr, dopts...)
cancel()
if err != nil { if err != nil {
grpclog.Warningf("grpclb: failed to setup a connection to the remote balancer %v: %v", rb.addr, err) grpclog.Warningf("grpclb: failed to setup a connection to the remote balancer %v: %v", rb.addr, err)
close(ccError) close(ccError)
@ -488,7 +491,7 @@ func (b *balancer) Start(target string, config BalancerConfig) error {
return nil return nil
} }
func (b *balancer) down(addr Address, err error) { func (b *grpclbBalancer) down(addr Address, err error) {
b.mu.Lock() b.mu.Lock()
defer b.mu.Unlock() defer b.mu.Unlock()
for _, a := range b.addrs { for _, a := range b.addrs {
@ -499,7 +502,7 @@ func (b *balancer) down(addr Address, err error) {
} }
} }
func (b *balancer) Up(addr Address) func(error) { func (b *grpclbBalancer) Up(addr Address) func(error) {
b.mu.Lock() b.mu.Lock()
defer b.mu.Unlock() defer b.mu.Unlock()
if b.done { if b.done {
@ -527,7 +530,7 @@ func (b *balancer) Up(addr Address) func(error) {
} }
} }
func (b *balancer) Get(ctx context.Context, opts BalancerGetOptions) (addr Address, put func(), err error) { func (b *grpclbBalancer) Get(ctx context.Context, opts BalancerGetOptions) (addr Address, put func(), err error) {
var ch chan struct{} var ch chan struct{}
b.mu.Lock() b.mu.Lock()
if b.done { if b.done {
@ -597,19 +600,12 @@ func (b *balancer) Get(ctx context.Context, opts BalancerGetOptions) (addr Addre
} }
} }
if !opts.BlockingWait { if !opts.BlockingWait {
if len(b.addrs) == 0 {
b.clientStats.NumCallsFinished++ b.clientStats.NumCallsFinished++
b.clientStats.NumCallsFinishedWithClientFailedToSend++ b.clientStats.NumCallsFinishedWithClientFailedToSend++
b.mu.Unlock() b.mu.Unlock()
err = Errorf(codes.Unavailable, "there is no address available") err = Errorf(codes.Unavailable, "there is no address available")
return return
} }
// Returns the next addr on b.addrs for a failfast RPC.
addr = b.addrs[b.next].addr
b.next++
b.mu.Unlock()
return
}
// Wait on b.waitCh for non-failfast RPCs. // Wait on b.waitCh for non-failfast RPCs.
if b.waitCh == nil { if b.waitCh == nil {
ch = make(chan struct{}) ch = make(chan struct{})
@ -684,11 +680,11 @@ func (b *balancer) Get(ctx context.Context, opts BalancerGetOptions) (addr Addre
} }
} }
func (b *balancer) Notify() <-chan []Address { func (b *grpclbBalancer) Notify() <-chan []Address {
return b.addrCh return b.addrCh
} }
func (b *balancer) Close() error { func (b *grpclbBalancer) Close() error {
b.mu.Lock() b.mu.Lock()
defer b.mu.Unlock() defer b.mu.Unlock()
if b.done { if b.done {


@ -1,4 +1,4 @@
// +build go1.6, !go1.8 // +build go1.7, !go1.8
/* /*
* *

cmd/vendor/google.golang.org/grpc/picker_wrapper.go generated vendored Normal file

@ -0,0 +1,141 @@
/*
*
* Copyright 2017 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package grpc
import (
"sync"
"golang.org/x/net/context"
"google.golang.org/grpc/balancer"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/grpclog"
"google.golang.org/grpc/status"
"google.golang.org/grpc/transport"
)
// pickerWrapper is a wrapper of balancer.Picker. It blocks on certain pick
// actions and unblocks when there's a picker update.
type pickerWrapper struct {
mu sync.Mutex
done bool
blockingCh chan struct{}
picker balancer.Picker
}
func newPickerWrapper() *pickerWrapper {
bp := &pickerWrapper{blockingCh: make(chan struct{})}
return bp
}
// updatePicker is called by UpdateBalancerState. It unblocks all blocked picks.
func (bp *pickerWrapper) updatePicker(p balancer.Picker) {
bp.mu.Lock()
if bp.done {
bp.mu.Unlock()
return
}
bp.picker = p
// bp.blockingCh should never be nil.
close(bp.blockingCh)
bp.blockingCh = make(chan struct{})
bp.mu.Unlock()
}
// pick returns the transport that will be used for the RPC.
// It may block in the following cases:
// - there's no picker
// - the current picker returns ErrNoSubConnAvailable
// - the current picker returns other errors and failfast is false.
// - the subConn returned by the current picker is not READY
// When one of these situations happens, pick blocks until the picker gets updated.
func (bp *pickerWrapper) pick(ctx context.Context, failfast bool, opts balancer.PickOptions) (transport.ClientTransport, func(balancer.DoneInfo), error) {
var (
p balancer.Picker
ch chan struct{}
)
for {
bp.mu.Lock()
if bp.done {
bp.mu.Unlock()
return nil, nil, ErrClientConnClosing
}
if bp.picker == nil {
ch = bp.blockingCh
}
if ch == bp.blockingCh {
// This could happen when either:
// - bp.picker is nil (the previous if condition), or
// - pick has already been called on the current picker.
bp.mu.Unlock()
select {
case <-ctx.Done():
return nil, nil, ctx.Err()
case <-ch:
}
continue
}
ch = bp.blockingCh
p = bp.picker
bp.mu.Unlock()
subConn, put, err := p.Pick(ctx, opts)
if err != nil {
switch err {
case balancer.ErrNoSubConnAvailable:
continue
case balancer.ErrTransientFailure:
if !failfast {
continue
}
return nil, nil, status.Errorf(codes.Unavailable, "%v", err)
default:
// err is some other error.
return nil, nil, toRPCErr(err)
}
}
acw, ok := subConn.(*acBalancerWrapper)
if !ok {
grpclog.Infof("subconn returned from pick is not *acBalancerWrapper")
continue
}
if t, ok := acw.getAddrConn().getReadyTransport(); ok {
return t, put, nil
}
grpclog.Infof("blockingPicker: the picked transport is not ready, loop back to repick")
// If ok == false, ac.state is not READY.
// A valid picker always returns a READY subConn. This means the state of ac
// just changed, and the picker will be updated shortly.
// Continue back to the beginning of the for loop to repick.
}
}
func (bp *pickerWrapper) close() {
bp.mu.Lock()
defer bp.mu.Unlock()
if bp.done {
return
}
bp.done = true
close(bp.blockingCh)
}
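A small standalone sketch (hypothetical names, not part of this change) of the close-and-replace notification pattern that updatePicker and pick rely on: every update closes the current channel so all blocked picks wake up, then installs a fresh channel for the next round of waiters.

package main

import (
	"fmt"
	"sync"
)

type notifier struct {
	mu sync.Mutex
	ch chan struct{}
}

func newNotifier() *notifier { return &notifier{ch: make(chan struct{})} }

// wait returns a channel that will be closed by the next update.
func (n *notifier) wait() <-chan struct{} {
	n.mu.Lock()
	defer n.mu.Unlock()
	return n.ch
}

// update wakes every goroutine blocked on the current channel and resets
// the channel for future waiters, just like updatePicker does for blockingCh.
func (n *notifier) update() {
	n.mu.Lock()
	close(n.ch)
	n.ch = make(chan struct{})
	n.mu.Unlock()
}

func main() {
	n := newNotifier()
	ch := n.wait() // take the current round's channel
	go n.update()  // a "picker update" arrives concurrently
	<-ch           // all waiters on this round unblock here
	fmt.Println("picker updated; loop back and pick again")
}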

cmd/vendor/google.golang.org/grpc/pickfirst.go generated vendored Normal file

@ -0,0 +1,95 @@
/*
*
* Copyright 2017 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package grpc
import (
"golang.org/x/net/context"
"google.golang.org/grpc/balancer"
"google.golang.org/grpc/connectivity"
"google.golang.org/grpc/grpclog"
"google.golang.org/grpc/resolver"
)
func newPickfirstBuilder() balancer.Builder {
return &pickfirstBuilder{}
}
type pickfirstBuilder struct{}
func (*pickfirstBuilder) Build(cc balancer.ClientConn, opt balancer.BuildOptions) balancer.Balancer {
return &pickfirstBalancer{cc: cc}
}
func (*pickfirstBuilder) Name() string {
return "pickfirst"
}
type pickfirstBalancer struct {
cc balancer.ClientConn
sc balancer.SubConn
}
func (b *pickfirstBalancer) HandleResolvedAddrs(addrs []resolver.Address, err error) {
if err != nil {
grpclog.Infof("pickfirstBalancer: HandleResolvedAddrs called with error %v", err)
return
}
if b.sc == nil {
b.sc, err = b.cc.NewSubConn(addrs, balancer.NewSubConnOptions{})
if err != nil {
grpclog.Errorf("pickfirstBalancer: failed to NewSubConn: %v", err)
return
}
b.cc.UpdateBalancerState(connectivity.Idle, &picker{sc: b.sc})
} else {
b.sc.UpdateAddresses(addrs)
}
}
func (b *pickfirstBalancer) HandleSubConnStateChange(sc balancer.SubConn, s connectivity.State) {
grpclog.Infof("pickfirstBalancer: HandleSubConnStateChange: %p, %v", sc, s)
if b.sc != sc || s == connectivity.Shutdown {
b.sc = nil
return
}
switch s {
case connectivity.Ready, connectivity.Idle:
b.cc.UpdateBalancerState(s, &picker{sc: sc})
case connectivity.Connecting:
b.cc.UpdateBalancerState(s, &picker{err: balancer.ErrNoSubConnAvailable})
case connectivity.TransientFailure:
b.cc.UpdateBalancerState(s, &picker{err: balancer.ErrTransientFailure})
}
}
func (b *pickfirstBalancer) Close() {
}
type picker struct {
err error
sc balancer.SubConn
}
func (p *picker) Pick(ctx context.Context, opts balancer.PickOptions) (balancer.SubConn, func(balancer.DoneInfo), error) {
if p.err != nil {
return nil, nil, p.err
}
return p.sc, nil, nil
}


@ -82,7 +82,8 @@ func doHTTPConnectHandshake(ctx context.Context, conn net.Conn, addr string) (_
Header: map[string][]string{"User-Agent": {grpcUA}}, Header: map[string][]string{"User-Agent": {grpcUA}},
}) })
if err := sendHTTPRequest(ctx, req, conn); err != nil { req = req.WithContext(ctx)
if err := req.Write(conn); err != nil {
return nil, fmt.Errorf("failed to write the HTTP request: %v", err) return nil, fmt.Errorf("failed to write the HTTP request: %v", err)
} }

cmd/vendor/google.golang.org/grpc/resolver/resolver.go generated vendored Normal file

@ -0,0 +1,143 @@
/*
*
* Copyright 2017 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
// Package resolver defines APIs for name resolution in gRPC.
// All APIs in this package are experimental.
package resolver
var (
// m is a map from scheme to resolver builder.
m = make(map[string]Builder)
// defaultScheme is the default scheme to use.
defaultScheme string
)
// TODO(bar) install dns resolver in init(){}.
// Register registers the resolver builder to the resolver map.
// b.Scheme will be used as the scheme registered with this builder.
func Register(b Builder) {
m[b.Scheme()] = b
}
// Get returns the resolver builder registered with the given scheme.
// If no builder is registered with the scheme, the default scheme will
// be used.
// If the default scheme is not modified, "dns" will be the default
// scheme, and the preinstalled dns resolver will be used.
// If the default scheme is modified, and a resolver is registered with
// the scheme, that resolver will be returned.
// If the default scheme is modified, and no resolver is registered with
// the scheme, nil will be returned.
func Get(scheme string) Builder {
if b, ok := m[scheme]; ok {
return b
}
if b, ok := m[defaultScheme]; ok {
return b
}
return nil
}
// SetDefaultScheme sets the default scheme that will be used.
// Unless overridden, the default scheme is "dns".
func SetDefaultScheme(scheme string) {
defaultScheme = scheme
}
// AddressType indicates the address type returned by name resolution.
type AddressType uint8
const (
// Backend indicates the address is for a backend server.
Backend AddressType = iota
// GRPCLB indicates the address is for a grpclb load balancer.
GRPCLB
)
// Address represents a server the client connects to.
// This is the EXPERIMENTAL API and may be changed or extended in the future.
type Address struct {
// Addr is the server address on which a connection will be established.
Addr string
// Type is the type of this address.
Type AddressType
// ServerName is the name of this address.
// It's the name of the grpc load balancer, which will be used for authentication.
ServerName string
// Metadata is the information associated with Addr, which may be used
// to make load balancing decisions.
Metadata interface{}
}
// BuildOption includes additional information for the builder to create
// the resolver.
type BuildOption struct {
}
// ClientConn contains the callbacks for resolver to notify any updates
// to the gRPC ClientConn.
type ClientConn interface {
// NewAddress is called by the resolver to notify the ClientConn of a new list
// of resolved addresses.
// The address list should be the complete list of resolved addresses.
NewAddress(addresses []Address)
// NewServiceConfig is called by the resolver to notify the ClientConn of a new
// service config. The service config should be provided as a JSON string.
NewServiceConfig(serviceConfig string)
}
// Target represents a target for gRPC, as specified in:
// https://github.com/grpc/grpc/blob/master/doc/naming.md.
type Target struct {
Scheme string
Authority string
Endpoint string
}
// Builder creates a resolver that will be used to watch name resolution updates.
type Builder interface {
// Build creates a new resolver for the given target.
//
// gRPC dial calls Build synchronously, and fails if the returned error is
// not nil.
Build(target Target, cc ClientConn, opts BuildOption) (Resolver, error)
// Scheme returns the scheme supported by this resolver.
// Scheme is defined at https://github.com/grpc/grpc/blob/master/doc/naming.md.
Scheme() string
}
// ResolveNowOption includes additional information for ResolveNow.
type ResolveNowOption struct{}
// Resolver watches for the updates on the specified target.
// Updates include address updates and service config updates.
type Resolver interface {
// ResolveNow will be called by gRPC to try to resolve the target name again.
// It's just a hint, resolver can ignore this if it's not necessary.
ResolveNow(ResolveNowOption)
// Close closes the resolver.
Close()
}
// UnregisterForTesting removes the resolver builder with the given scheme from the
// resolver map.
// This function is for testing only.
func UnregisterForTesting(scheme string) {
delete(m, scheme)
}
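To make the Builder/Resolver/ClientConn contract above concrete, here is a hedged sketch of a trivial resolver that always reports a fixed address list; the package name, the "static" scheme and the address are made up for illustration and are not part of this change:

package staticresolver

import "google.golang.org/grpc/resolver"

// builder implements resolver.Builder for a made-up "static" scheme that
// always resolves to a fixed list of addresses.
type builder struct {
	addrs []resolver.Address
}

func (b *builder) Build(target resolver.Target, cc resolver.ClientConn, opts resolver.BuildOption) (resolver.Resolver, error) {
	r := &staticResolver{cc: cc, addrs: b.addrs}
	cc.NewAddress(b.addrs) // report the complete address list up front
	return r, nil
}

func (b *builder) Scheme() string { return "static" }

// staticResolver implements resolver.Resolver; it has nothing to watch.
type staticResolver struct {
	cc    resolver.ClientConn
	addrs []resolver.Address
}

// ResolveNow just re-sends the same list; a real resolver would re-query.
func (r *staticResolver) ResolveNow(resolver.ResolveNowOption) { r.cc.NewAddress(r.addrs) }

func (r *staticResolver) Close() {}

func init() {
	// Register the builder so targets with the "static" scheme use it.
	resolver.Register(&builder{addrs: []resolver.Address{{Addr: "127.0.0.1:2379"}}})
}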


@ -0,0 +1,139 @@
/*
*
* Copyright 2017 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package grpc
import (
"strings"
"google.golang.org/grpc/grpclog"
"google.golang.org/grpc/resolver"
)
// ccResolverWrapper is a wrapper on top of cc for resolvers.
// It implements the resolver.ClientConn interface.
type ccResolverWrapper struct {
cc *ClientConn
resolver resolver.Resolver
addrCh chan []resolver.Address
scCh chan string
done chan struct{}
}
// split2 returns the values from strings.SplitN(s, sep, 2).
// If sep is not found, it returns "", s instead.
func split2(s, sep string) (string, string) {
spl := strings.SplitN(s, sep, 2)
if len(spl) < 2 {
return "", s
}
return spl[0], spl[1]
}
// parseTarget splits target into a struct containing scheme, authority and
// endpoint.
func parseTarget(target string) (ret resolver.Target) {
ret.Scheme, ret.Endpoint = split2(target, "://")
ret.Authority, ret.Endpoint = split2(ret.Endpoint, "/")
return ret
}
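// For example:
//   parseTarget("dns://some-authority/backend.example.com:443")
//     => Target{Scheme: "dns", Authority: "some-authority", Endpoint: "backend.example.com:443"}
//   parseTarget("localhost:2379") // no scheme, no authority
//     => Target{Scheme: "", Authority: "", Endpoint: "localhost:2379"}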
// newCCResolverWrapper parses cc.target for scheme and gets the resolver
// builder for this scheme. It then builds the resolver and starts the
// monitoring goroutine for it.
//
// This function could return nil, nil in tests for old behaviors.
// TODO(bar) never return nil, nil when DNS becomes the default resolver.
func newCCResolverWrapper(cc *ClientConn) (*ccResolverWrapper, error) {
target := parseTarget(cc.target)
grpclog.Infof("dialing to target with scheme: %q", target.Scheme)
rb := resolver.Get(target.Scheme)
if rb == nil {
// TODO(bar) return error when DNS becomes the default (implemented and
// registered by DNS package).
grpclog.Infof("could not get resolver for scheme: %q", target.Scheme)
return nil, nil
}
ccr := &ccResolverWrapper{
cc: cc,
addrCh: make(chan []resolver.Address, 1),
scCh: make(chan string, 1),
done: make(chan struct{}),
}
var err error
ccr.resolver, err = rb.Build(target, ccr, resolver.BuildOption{})
if err != nil {
return nil, err
}
go ccr.watcher()
return ccr, nil
}
// watcher processes address updates and service config updates sequentially.
// Otherwise, we would need to resolve possible races between address and service
// config updates (e.g. they specify different balancer types).
func (ccr *ccResolverWrapper) watcher() {
for {
select {
case <-ccr.done:
return
default:
}
select {
case addrs := <-ccr.addrCh:
grpclog.Infof("ccResolverWrapper: sending new addresses to balancer wrapper: %v", addrs)
// TODO(bar switching) this should never be nil. Pickfirst should be default.
if ccr.cc.balancerWrapper != nil {
// TODO(bar switching) create balancer if it's nil?
ccr.cc.balancerWrapper.handleResolvedAddrs(addrs, nil)
}
case sc := <-ccr.scCh:
grpclog.Infof("ccResolverWrapper: got new service config: %v", sc)
case <-ccr.done:
return
}
}
}
func (ccr *ccResolverWrapper) close() {
ccr.resolver.Close()
close(ccr.done)
}
// NewAddress is called by the resolver implementation to send addresses to gRPC.
func (ccr *ccResolverWrapper) NewAddress(addrs []resolver.Address) {
select {
case <-ccr.addrCh:
default:
}
ccr.addrCh <- addrs
}
// NewServiceConfig is called by the resolver implementation to send service
// configs to gRPC.
func (ccr *ccResolverWrapper) NewServiceConfig(sc string) {
select {
case <-ccr.scCh:
default:
}
ccr.scCh <- sc
}
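A small standalone sketch (made-up names) of the drain-then-send idiom NewAddress and NewServiceConfig use: with a buffer of one, a stale unconsumed update is dropped so the channel always holds the most recent value and the sender never blocks.

package main

import "fmt"

func main() {
	updates := make(chan string, 1)

	// send replaces any pending update with the newest one, mirroring the
	// select { case <-ch: default: }; ch <- v pattern above.
	send := func(v string) {
		select {
		case <-updates: // drop the stale, unconsumed update
		default:
		}
		updates <- v
	}

	send("addrs: [10.0.0.1:2379]")
	send("addrs: [10.0.0.1:2379 10.0.0.2:2379]") // overwrites the first update
	fmt.Println(<-updates)                       // only the latest is delivered
}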


@ -21,10 +21,12 @@ package grpc
import ( import (
"bytes" "bytes"
"compress/gzip" "compress/gzip"
stdctx "context"
"encoding/binary" "encoding/binary"
"io" "io"
"io/ioutil" "io/ioutil"
"math" "math"
"os"
"sync" "sync"
"time" "time"
@ -132,7 +134,9 @@ type callInfo struct {
creds credentials.PerRPCCredentials creds credentials.PerRPCCredentials
} }
var defaultCallInfo = callInfo{failFast: true} func defaultCallInfo() *callInfo {
return &callInfo{failFast: true}
}
// CallOption configures a Call before it starts or extracts information from // CallOption configures a Call before it starts or extracts information from
// a Call after it completes. // a Call after it completes.
@ -384,14 +388,15 @@ func recv(p *parser, c Codec, s *transport.Stream, dc Decompressor, m interface{
} }
type rpcInfo struct { type rpcInfo struct {
failfast bool
bytesSent bool bytesSent bool
bytesReceived bool bytesReceived bool
} }
type rpcInfoContextKey struct{} type rpcInfoContextKey struct{}
func newContextWithRPCInfo(ctx context.Context) context.Context { func newContextWithRPCInfo(ctx context.Context, failfast bool) context.Context {
return context.WithValue(ctx, rpcInfoContextKey{}, &rpcInfo{}) return context.WithValue(ctx, rpcInfoContextKey{}, &rpcInfo{failfast: failfast})
} }
func rpcInfoFromContext(ctx context.Context) (s *rpcInfo, ok bool) { func rpcInfoFromContext(ctx context.Context) (s *rpcInfo, ok bool) {
@ -401,11 +406,63 @@ func rpcInfoFromContext(ctx context.Context) (s *rpcInfo, ok bool) {
func updateRPCInfoInContext(ctx context.Context, s rpcInfo) { func updateRPCInfoInContext(ctx context.Context, s rpcInfo) {
if ss, ok := rpcInfoFromContext(ctx); ok { if ss, ok := rpcInfoFromContext(ctx); ok {
*ss = s ss.bytesReceived = s.bytesReceived
ss.bytesSent = s.bytesSent
} }
return return
} }
// toRPCErr converts an error into an error from the status package.
func toRPCErr(err error) error {
if _, ok := status.FromError(err); ok {
return err
}
switch e := err.(type) {
case transport.StreamError:
return status.Error(e.Code, e.Desc)
case transport.ConnectionError:
return status.Error(codes.Unavailable, e.Desc)
default:
switch err {
case context.DeadlineExceeded, stdctx.DeadlineExceeded:
return status.Error(codes.DeadlineExceeded, err.Error())
case context.Canceled, stdctx.Canceled:
return status.Error(codes.Canceled, err.Error())
case ErrClientConnClosing:
return status.Error(codes.FailedPrecondition, err.Error())
}
}
return status.Error(codes.Unknown, err.Error())
}
// convertCode converts a standard Go error into its canonical code. Note that
// this is only used to translate the error returned by the server applications.
func convertCode(err error) codes.Code {
switch err {
case nil:
return codes.OK
case io.EOF:
return codes.OutOfRange
case io.ErrClosedPipe, io.ErrNoProgress, io.ErrShortBuffer, io.ErrShortWrite, io.ErrUnexpectedEOF:
return codes.FailedPrecondition
case os.ErrInvalid:
return codes.InvalidArgument
case context.Canceled, stdctx.Canceled:
return codes.Canceled
case context.DeadlineExceeded, stdctx.DeadlineExceeded:
return codes.DeadlineExceeded
}
switch {
case os.IsExist(err):
return codes.AlreadyExists
case os.IsNotExist(err):
return codes.NotFound
case os.IsPermission(err):
return codes.PermissionDenied
}
return codes.Unknown
}
// Code returns the error code for err if it was produced by the rpc system. // Code returns the error code for err if it was produced by the rpc system.
// Otherwise, it returns codes.Unknown. // Otherwise, it returns codes.Unknown.
// //
@ -452,7 +509,7 @@ type MethodConfig struct {
// MaxReqSize is the maximum allowed payload size for an individual request in a // MaxReqSize is the maximum allowed payload size for an individual request in a
// stream (client->server) in bytes. The size which is measured is the serialized // stream (client->server) in bytes. The size which is measured is the serialized
// payload after per-message compression (but before stream compression) in bytes. // payload after per-message compression (but before stream compression) in bytes.
// The actual value used is the minumum of the value specified here and the value set // The actual value used is the minimum of the value specified here and the value set
// by the application via the gRPC client API. If either one is not set, then the other // by the application via the gRPC client API. If either one is not set, then the other
// will be used. If neither is set, then the built-in default is used. // will be used. If neither is set, then the built-in default is used.
MaxReqSize *int MaxReqSize *int
@ -497,7 +554,7 @@ func getMaxSize(mcMax, doptMax *int, defaultVal int) *int {
// SupportPackageIsVersion3 is referenced from generated protocol buffer files. // SupportPackageIsVersion3 is referenced from generated protocol buffer files.
// The latest support package version is 4. // The latest support package version is 4.
// SupportPackageIsVersion3 is kept for compability. It will be removed in the // SupportPackageIsVersion3 is kept for compatibility. It will be removed in the
// next support package version update. // next support package version update.
const SupportPackageIsVersion3 = true const SupportPackageIsVersion3 = true
@ -510,6 +567,6 @@ const SupportPackageIsVersion3 = true
const SupportPackageIsVersion4 = true const SupportPackageIsVersion4 = true
// Version is the current grpc version. // Version is the current grpc version.
const Version = "1.6.0" const Version = "1.7.2"
const grpcUA = "grpc-go/" + Version const grpcUA = "grpc-go/" + Version


@ -116,6 +116,8 @@ type options struct {
keepalivePolicy keepalive.EnforcementPolicy keepalivePolicy keepalive.EnforcementPolicy
initialWindowSize int32 initialWindowSize int32
initialConnWindowSize int32 initialConnWindowSize int32
writeBufferSize int
readBufferSize int
} }
var defaultServerOptions = options{ var defaultServerOptions = options{
@ -126,6 +128,22 @@ var defaultServerOptions = options{
// A ServerOption sets options such as credentials, codec and keepalive parameters, etc. // A ServerOption sets options such as credentials, codec and keepalive parameters, etc.
type ServerOption func(*options) type ServerOption func(*options)
// WriteBufferSize lets you set the size of the write buffer; this determines how much data can be batched
// before doing a write on the wire.
func WriteBufferSize(s int) ServerOption {
return func(o *options) {
o.writeBufferSize = s
}
}
// ReadBufferSize lets you set the size of the read buffer; this determines how much data can be read at most
// for one read syscall.
func ReadBufferSize(s int) ServerOption {
return func(o *options) {
o.readBufferSize = s
}
}
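// Illustrative use of the two new options (the buffer sizes are arbitrary):
//   s := grpc.NewServer(
//   	grpc.WriteBufferSize(64*1024), // batch up to 64 KiB before each write
//   	grpc.ReadBufferSize(64*1024),  // read up to 64 KiB per read syscall
//   )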
// InitialWindowSize returns a ServerOption that sets window size for stream. // InitialWindowSize returns a ServerOption that sets window size for stream.
// The lower bound for window size is 64K and any value smaller than that will be ignored. // The lower bound for window size is 64K and any value smaller than that will be ignored.
func InitialWindowSize(s int32) ServerOption { func InitialWindowSize(s int32) ServerOption {
@ -260,7 +278,7 @@ func StatsHandler(h stats.Handler) ServerOption {
// handler that will be invoked instead of returning the "unimplemented" gRPC // handler that will be invoked instead of returning the "unimplemented" gRPC
// error whenever a request is received for an unregistered service or method. // error whenever a request is received for an unregistered service or method.
// The handling function has full access to the Context of the request and the // The handling function has full access to the Context of the request and the
// stream, and the invocation passes through interceptors. // stream, and the invocation bypasses interceptors.
func UnknownServiceHandler(streamHandler StreamHandler) ServerOption { func UnknownServiceHandler(streamHandler StreamHandler) ServerOption {
return func(o *options) { return func(o *options) {
o.unknownStreamDesc = &StreamDesc{ o.unknownStreamDesc = &StreamDesc{
@ -524,6 +542,8 @@ func (s *Server) serveHTTP2Transport(c net.Conn, authInfo credentials.AuthInfo)
KeepalivePolicy: s.opts.keepalivePolicy, KeepalivePolicy: s.opts.keepalivePolicy,
InitialWindowSize: s.opts.initialWindowSize, InitialWindowSize: s.opts.initialWindowSize,
InitialConnWindowSize: s.opts.initialConnWindowSize, InitialConnWindowSize: s.opts.initialConnWindowSize,
WriteBufferSize: s.opts.writeBufferSize,
ReadBufferSize: s.opts.readBufferSize,
} }
st, err := transport.NewServerTransport("http2", c, config) st, err := transport.NewServerTransport("http2", c, config)
if err != nil { if err != nil {
@ -891,9 +911,6 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp
trInfo: trInfo, trInfo: trInfo,
statsHandler: sh, statsHandler: sh,
} }
if ss.cp != nil {
ss.cbuf = new(bytes.Buffer)
}
if trInfo != nil { if trInfo != nil {
trInfo.tr.LazyLog(&trInfo.firstLine, false) trInfo.tr.LazyLog(&trInfo.firstLine, false)
defer func() { defer func() {


@ -135,8 +135,6 @@ func (s *OutPayload) isRPCStats() {}
type OutHeader struct { type OutHeader struct {
// Client is true if this OutHeader is from client side. // Client is true if this OutHeader is from client side.
Client bool Client bool
// WireLength is the wire length of header.
WireLength int
// The following fields are valid only if Client is true. // The following fields are valid only if Client is true.
// FullMethod is the full RPC method string, i.e., /package.service/method. // FullMethod is the full RPC method string, i.e., /package.service/method.
@ -220,7 +218,7 @@ type outgoingTagsKey struct{}
// the outgoing RPC with the header grpc-tags-bin. Subsequent calls to // the outgoing RPC with the header grpc-tags-bin. Subsequent calls to
// SetTags will overwrite the values from earlier calls. // SetTags will overwrite the values from earlier calls.
// //
// NOTE: this is provided only for backward compatibilty with existing clients // NOTE: this is provided only for backward compatibility with existing clients
// and will likely be removed in an upcoming release. New uses should transmit // and will likely be removed in an upcoming release. New uses should transmit
// this type of data using metadata with a different, non-reserved (i.e. does // this type of data using metadata with a different, non-reserved (i.e. does
// not begin with "grpc-") header name. // not begin with "grpc-") header name.
@ -230,7 +228,7 @@ func SetTags(ctx context.Context, b []byte) context.Context {
// Tags returns the tags from the context for the inbound RPC. // Tags returns the tags from the context for the inbound RPC.
// //
// NOTE: this is provided only for backward compatibilty with existing clients // NOTE: this is provided only for backward compatibility with existing clients
// and will likely be removed in an upcoming release. New uses should transmit // and will likely be removed in an upcoming release. New uses should transmit
// this type of data using metadata with a different, non-reserved (i.e. does // this type of data using metadata with a different, non-reserved (i.e. does
// not begin with "grpc-") header name. // not begin with "grpc-") header name.
@ -262,7 +260,7 @@ type outgoingTraceKey struct{}
// the outgoing RPC with the header grpc-trace-bin. Subsequent calls to // the outgoing RPC with the header grpc-trace-bin. Subsequent calls to
// SetTrace will overwrite the values from earlier calls. // SetTrace will overwrite the values from earlier calls.
// //
// NOTE: this is provided only for backward compatibilty with existing clients // NOTE: this is provided only for backward compatibility with existing clients
// and will likely be removed in an upcoming release. New uses should transmit // and will likely be removed in an upcoming release. New uses should transmit
// this type of data using metadata with a different, non-reserved (i.e. does // this type of data using metadata with a different, non-reserved (i.e. does
// not begin with "grpc-") header name. // not begin with "grpc-") header name.
@ -272,7 +270,7 @@ func SetTrace(ctx context.Context, b []byte) context.Context {
// Trace returns the trace from the context for the inbound RPC. // Trace returns the trace from the context for the inbound RPC.
// //
// NOTE: this is provided only for backward compatibilty with existing clients // NOTE: this is provided only for backward compatibility with existing clients
// and will likely be removed in an upcoming release. New uses should transmit // and will likely be removed in an upcoming release. New uses should transmit
// this type of data using metadata with a different, non-reserved (i.e. does // this type of data using metadata with a different, non-reserved (i.e. does
// not begin with "grpc-") header name. // not begin with "grpc-") header name.


@ -27,6 +27,7 @@ import (
"golang.org/x/net/context" "golang.org/x/net/context"
"golang.org/x/net/trace" "golang.org/x/net/trace"
"google.golang.org/grpc/balancer"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata" "google.golang.org/grpc/metadata"
"google.golang.org/grpc/peer" "google.golang.org/grpc/peer"
@ -106,10 +107,10 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
var ( var (
t transport.ClientTransport t transport.ClientTransport
s *transport.Stream s *transport.Stream
put func() done func(balancer.DoneInfo)
cancel context.CancelFunc cancel context.CancelFunc
) )
c := defaultCallInfo c := defaultCallInfo()
mc := cc.GetMethodConfig(method) mc := cc.GetMethodConfig(method)
if mc.WaitForReady != nil { if mc.WaitForReady != nil {
c.failFast = !*mc.WaitForReady c.failFast = !*mc.WaitForReady
@ -126,7 +127,7 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
opts = append(cc.dopts.callOptions, opts...) opts = append(cc.dopts.callOptions, opts...)
for _, o := range opts { for _, o := range opts {
if err := o.before(&c); err != nil { if err := o.before(c); err != nil {
return nil, toRPCErr(err) return nil, toRPCErr(err)
} }
} }
@ -167,7 +168,7 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
} }
}() }()
} }
ctx = newContextWithRPCInfo(ctx) ctx = newContextWithRPCInfo(ctx, c.failFast)
sh := cc.dopts.copts.StatsHandler sh := cc.dopts.copts.StatsHandler
if sh != nil { if sh != nil {
ctx = sh.TagRPC(ctx, &stats.RPCTagInfo{FullMethodName: method, FailFast: c.failFast}) ctx = sh.TagRPC(ctx, &stats.RPCTagInfo{FullMethodName: method, FailFast: c.failFast})
@ -188,11 +189,8 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
} }
}() }()
} }
gopts := BalancerGetOptions{
BlockingWait: !c.failFast,
}
for { for {
t, put, err = cc.getTransport(ctx, gopts) t, done, err = cc.getTransport(ctx, c.failFast)
if err != nil { if err != nil {
// TODO(zhaoq): Probably revisit the error handling. // TODO(zhaoq): Probably revisit the error handling.
if _, ok := status.FromError(err); ok { if _, ok := status.FromError(err); ok {
@ -210,15 +208,15 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
s, err = t.NewStream(ctx, callHdr) s, err = t.NewStream(ctx, callHdr)
if err != nil { if err != nil {
if _, ok := err.(transport.ConnectionError); ok && put != nil { if _, ok := err.(transport.ConnectionError); ok && done != nil {
// If error is connection error, transport was sending data on wire, // If error is connection error, transport was sending data on wire,
// and we are not sure if anything has been sent on wire. // and we are not sure if anything has been sent on wire.
// If error is not connection error, we are sure nothing has been sent. // If error is not connection error, we are sure nothing has been sent.
updateRPCInfoInContext(ctx, rpcInfo{bytesSent: true, bytesReceived: false}) updateRPCInfoInContext(ctx, rpcInfo{bytesSent: true, bytesReceived: false})
} }
if put != nil { if done != nil {
put() done(balancer.DoneInfo{Err: err})
put = nil done = nil
} }
if _, ok := err.(transport.ConnectionError); (ok || err == transport.ErrStreamDrain) && !c.failFast { if _, ok := err.(transport.ConnectionError); (ok || err == transport.ErrStreamDrain) && !c.failFast {
continue continue
@ -240,7 +238,7 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
dc: cc.dopts.dc, dc: cc.dopts.dc,
cancel: cancel, cancel: cancel,
put: put, done: done,
t: t, t: t,
s: s, s: s,
p: &parser{r: s}, p: &parser{r: s},
@ -251,9 +249,6 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
statsCtx: ctx, statsCtx: ctx,
statsHandler: cc.dopts.copts.StatsHandler, statsHandler: cc.dopts.copts.StatsHandler,
} }
if cc.dopts.cp != nil {
cs.cbuf = new(bytes.Buffer)
}
// Listen on ctx.Done() to detect cancellation and s.Done() to detect normal termination // Listen on ctx.Done() to detect cancellation and s.Done() to detect normal termination
// when there is no pending I/O operations on this stream. // when there is no pending I/O operations on this stream.
go func() { go func() {
@ -283,21 +278,20 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
// clientStream implements a client side Stream. // clientStream implements a client side Stream.
type clientStream struct { type clientStream struct {
opts []CallOption opts []CallOption
c callInfo c *callInfo
t transport.ClientTransport t transport.ClientTransport
s *transport.Stream s *transport.Stream
p *parser p *parser
desc *StreamDesc desc *StreamDesc
codec Codec codec Codec
cp Compressor cp Compressor
cbuf *bytes.Buffer
dc Decompressor dc Decompressor
cancel context.CancelFunc cancel context.CancelFunc
tracing bool // set to EnableTracing when the clientStream is created. tracing bool // set to EnableTracing when the clientStream is created.
mu sync.Mutex mu sync.Mutex
put func() done func(balancer.DoneInfo)
closed bool closed bool
finished bool finished bool
// trInfo.tr is set when the clientStream is created (if EnableTracing is true), // trInfo.tr is set when the clientStream is created (if EnableTracing is true),
@ -367,12 +361,7 @@ func (cs *clientStream) SendMsg(m interface{}) (err error) {
Client: true, Client: true,
} }
} }
hdr, data, err := encode(cs.codec, m, cs.cp, cs.cbuf, outPayload) hdr, data, err := encode(cs.codec, m, cs.cp, bytes.NewBuffer([]byte{}), outPayload)
defer func() {
if cs.cbuf != nil {
cs.cbuf.Reset()
}
}()
if err != nil { if err != nil {
return err return err
} }
@ -494,15 +483,15 @@ func (cs *clientStream) finish(err error) {
} }
}() }()
for _, o := range cs.opts { for _, o := range cs.opts {
o.after(&cs.c) o.after(cs.c)
} }
if cs.put != nil { if cs.done != nil {
updateRPCInfoInContext(cs.s.Context(), rpcInfo{ updateRPCInfoInContext(cs.s.Context(), rpcInfo{
bytesSent: cs.s.BytesSent(), bytesSent: cs.s.BytesSent(),
bytesReceived: cs.s.BytesReceived(), bytesReceived: cs.s.BytesReceived(),
}) })
cs.put() cs.done(balancer.DoneInfo{Err: err})
cs.put = nil cs.done = nil
} }
if cs.statsHandler != nil { if cs.statsHandler != nil {
end := &stats.End{ end := &stats.End{
@ -557,7 +546,6 @@ type serverStream struct {
codec Codec codec Codec
cp Compressor cp Compressor
dc Decompressor dc Decompressor
cbuf *bytes.Buffer
maxReceiveMessageSize int maxReceiveMessageSize int
maxSendMessageSize int maxSendMessageSize int
trInfo *traceInfo trInfo *traceInfo
@ -613,12 +601,7 @@ func (ss *serverStream) SendMsg(m interface{}) (err error) {
if ss.statsHandler != nil { if ss.statsHandler != nil {
outPayload = &stats.OutPayload{} outPayload = &stats.OutPayload{}
} }
hdr, data, err := encode(ss.codec, m, ss.cp, ss.cbuf, outPayload) hdr, data, err := encode(ss.codec, m, ss.cp, bytes.NewBuffer([]byte{}), outPayload)
defer func() {
if ss.cbuf != nil {
ss.cbuf.Reset()
}
}()
if err != nil { if err != nil {
return err return err
} }

View File

@ -31,7 +31,7 @@ import (
// EnableTracing controls whether to trace RPCs using the golang.org/x/net/trace package. // EnableTracing controls whether to trace RPCs using the golang.org/x/net/trace package.
// This should only be set before any RPCs are sent or received by this program. // This should only be set before any RPCs are sent or received by this program.
var EnableTracing = true var EnableTracing bool
// methodFamily returns the trace family for the given method. // methodFamily returns the trace family for the given method.
// It turns "/pkg.Service/GetFoo" into "pkg.Service". // It turns "/pkg.Service/GetFoo" into "pkg.Service".
@ -76,6 +76,15 @@ func (f *firstLine) String() string {
return line.String() return line.String()
} }
const truncateSize = 100
func truncate(x string, l int) string {
if l > len(x) {
return x
}
return x[:l]
}
// payload represents an RPC request or response payload. // payload represents an RPC request or response payload.
type payload struct { type payload struct {
sent bool // whether this is an outgoing payload sent bool // whether this is an outgoing payload
@ -85,9 +94,9 @@ type payload struct {
func (p payload) String() string { func (p payload) String() string {
if p.sent { if p.sent {
return fmt.Sprintf("sent: %v", p.msg) return truncate(fmt.Sprintf("sent: %v", p.msg), truncateSize)
} }
return fmt.Sprintf("recv: %v", p.msg) return truncate(fmt.Sprintf("recv: %v", p.msg), truncateSize)
} }
type fmtStringer struct { type fmtStringer struct {

View File

@ -59,7 +59,7 @@ type bdpEstimator struct {
sample uint32 sample uint32
// bwMax is the maximum bandwidth noted so far (bytes/sec). // bwMax is the maximum bandwidth noted so far (bytes/sec).
bwMax float64 bwMax float64
// bool to keep track of the begining of a new measurement cycle. // bool to keep track of the beginning of a new measurement cycle.
isSent bool isSent bool
// Callback to update the window sizes. // Callback to update the window sizes.
updateFlowControl func(n uint32) updateFlowControl func(n uint32)
@ -70,7 +70,7 @@ type bdpEstimator struct {
} }
// timesnap registers the time bdp ping was sent out so that // timesnap registers the time bdp ping was sent out so that
// network rtt can be calculated when its ack is recieved. // network rtt can be calculated when its ack is received.
// It is called (by controller) when the bdpPing is // It is called (by controller) when the bdpPing is
// being written on the wire. // being written on the wire.
func (b *bdpEstimator) timesnap(d [8]byte) { func (b *bdpEstimator) timesnap(d [8]byte) {
@ -119,7 +119,7 @@ func (b *bdpEstimator) calculate(d [8]byte) {
b.rtt += (rttSample - b.rtt) * float64(alpha) b.rtt += (rttSample - b.rtt) * float64(alpha)
} }
b.isSent = false b.isSent = false
// The number of bytes accumalated so far in the sample is smaller // The number of bytes accumulated so far in the sample is smaller
// than or equal to 1.5 times the real BDP on a saturated connection. // than or equal to 1.5 times the real BDP on a saturated connection.
bwCurrent := float64(b.sample) / (b.rtt * float64(1.5)) bwCurrent := float64(b.sample) / (b.rtt * float64(1.5))
if bwCurrent > b.bwMax { if bwCurrent > b.bwMax {
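
The hunk above only corrects comment spelling, but the calculation it sits in is easy to restate: the estimator divides the bytes observed between a bdp ping and its ack by 1.5 times the smoothed RTT to get a bandwidth sample, keeping the maximum seen so far. A minimal standalone sketch of that arithmetic (the function name and numbers are illustrative, not the transport's internals):

package main

import (
	"fmt"
	"time"
)

// bdpSample mirrors the shape of the calculation above: bytes accumulated in
// one ping/ack round divided by 1.5x the smoothed RTT gives one bandwidth sample.
func bdpSample(sampleBytes uint32, rtt time.Duration) float64 {
	return float64(sampleBytes) / (rtt.Seconds() * 1.5)
}

func main() {
	// Hypothetical numbers: 96 KiB accumulated over a 20ms round trip.
	bw := bdpSample(96*1024, 20*time.Millisecond)
	fmt.Printf("bandwidth sample: %.0f bytes/sec\n", bw)
}
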

View File

@ -22,9 +22,11 @@ import (
"fmt" "fmt"
"math" "math"
"sync" "sync"
"sync/atomic"
"time" "time"
"golang.org/x/net/http2" "golang.org/x/net/http2"
"golang.org/x/net/http2/hpack"
) )
const ( const (
@ -44,15 +46,44 @@ const (
defaultKeepalivePolicyMinTime = time.Duration(5 * time.Minute) defaultKeepalivePolicyMinTime = time.Duration(5 * time.Minute)
// max window limit set by HTTP2 Specs. // max window limit set by HTTP2 Specs.
maxWindowSize = math.MaxInt32 maxWindowSize = math.MaxInt32
// defaultLocalSendQuota sets the default value for the number of data
// bytes that each stream can schedule before some of it is
// flushed out.
defaultLocalSendQuota = 64 * 1024
) )
// The following defines various control items which could flow through // The following defines various control items which could flow through
// the control buffer of transport. They represent different aspects of // the control buffer of transport. They represent different aspects of
// control tasks, e.g., flow control, settings, streaming resetting, etc. // control tasks, e.g., flow control, settings, streaming resetting, etc.
type headerFrame struct {
streamID uint32
hf []hpack.HeaderField
endStream bool
}
func (*headerFrame) item() {}
type continuationFrame struct {
streamID uint32
endHeaders bool
headerBlockFragment []byte
}
type dataFrame struct {
streamID uint32
endStream bool
d []byte
f func()
}
func (*dataFrame) item() {}
func (*continuationFrame) item() {}
type windowUpdate struct { type windowUpdate struct {
streamID uint32 streamID uint32
increment uint32 increment uint32
flush bool
} }
func (*windowUpdate) item() {} func (*windowUpdate) item() {}
@ -98,6 +129,7 @@ type quotaPool struct {
c chan int c chan int
mu sync.Mutex mu sync.Mutex
version uint32
quota int quota int
} }
@ -119,6 +151,10 @@ func newQuotaPool(q int) *quotaPool {
func (qb *quotaPool) add(v int) { func (qb *quotaPool) add(v int) {
qb.mu.Lock() qb.mu.Lock()
defer qb.mu.Unlock() defer qb.mu.Unlock()
qb.lockedAdd(v)
}
func (qb *quotaPool) lockedAdd(v int) {
select { select {
case n := <-qb.c: case n := <-qb.c:
qb.quota += n qb.quota += n
@ -139,6 +175,35 @@ func (qb *quotaPool) add(v int) {
} }
} }
func (qb *quotaPool) addAndUpdate(v int) {
qb.mu.Lock()
defer qb.mu.Unlock()
qb.lockedAdd(v)
// Update the version only after having added to the quota
// so that if acquireWithVersion sees the new version it is
// guaranteed to have seen the updated quota.
// Also, still keep this inside of the lock, so that when
// compareAndExecute is processing, this function doesn't
// get executed partially (quota gets updated but the version
// doesn't).
atomic.AddUint32(&(qb.version), 1)
}
func (qb *quotaPool) acquireWithVersion() (<-chan int, uint32) {
return qb.c, atomic.LoadUint32(&(qb.version))
}
func (qb *quotaPool) compareAndExecute(version uint32, success, failure func()) bool {
qb.mu.Lock()
defer qb.mu.Unlock()
if version == atomic.LoadUint32(&(qb.version)) {
success()
return true
}
failure()
return false
}
// acquire returns the channel on which available quota amounts are sent. // acquire returns the channel on which available quota amounts are sent.
func (qb *quotaPool) acquire() <-chan int { func (qb *quotaPool) acquire() <-chan int {
return qb.c return qb.c
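
The addAndUpdate/acquireWithVersion/compareAndExecute trio added above is an optimistic check: a writer snapshots the quota together with a version counter, and later commits its bookkeeping only if no settings update bumped the version in the meantime. A minimal sketch of the same pattern under assumed names (not the transport's quotaPool, which hands out quota through a channel):

package main

import (
	"fmt"
	"sync"
	"sync/atomic"
)

// versionedPool sketches the optimistic-versioning idea: changes made through
// addAndUpdate bump a version, and compareAndExecute runs the success callback
// only if the version seen at acquire time is still current.
type versionedPool struct {
	mu      sync.Mutex
	version uint32
	quota   int
}

func (p *versionedPool) addAndUpdate(v int) {
	p.mu.Lock()
	defer p.mu.Unlock()
	p.quota += v
	// Bump the version only after the quota change so a reader that sees the
	// new version is guaranteed to see the updated quota as well.
	atomic.AddUint32(&p.version, 1)
}

func (p *versionedPool) acquireWithVersion() (int, uint32) {
	p.mu.Lock()
	defer p.mu.Unlock()
	return p.quota, atomic.LoadUint32(&p.version)
}

func (p *versionedPool) compareAndExecute(ver uint32, success, failure func()) bool {
	p.mu.Lock()
	defer p.mu.Unlock()
	if ver == atomic.LoadUint32(&p.version) {
		success()
		return true
	}
	failure()
	return false
}

func main() {
	p := &versionedPool{quota: 10}
	q, ver := p.acquireWithVersion()
	p.addAndUpdate(5) // a concurrent settings update invalidates the snapshot
	ok := p.compareAndExecute(ver,
		func() { fmt.Println("commit with quota", q) },
		func() { fmt.Println("version changed, retry") })
	fmt.Println("committed:", ok)
}
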

View File

@ -1,45 +0,0 @@
// +build go1.6,!go1.7
/*
*
* Copyright 2016 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package transport
import (
"net"
"google.golang.org/grpc/codes"
"golang.org/x/net/context"
)
// dialContext connects to the address on the named network.
func dialContext(ctx context.Context, network, address string) (net.Conn, error) {
return (&net.Dialer{Cancel: ctx.Done()}).Dial(network, address)
}
// ContextErr converts the error from context package into a StreamError.
func ContextErr(err error) StreamError {
switch err {
case context.DeadlineExceeded:
return streamErrorf(codes.DeadlineExceeded, "%v", err)
case context.Canceled:
return streamErrorf(codes.Canceled, "%v", err)
}
return streamErrorf(codes.Internal, "Unexpected error from context packet: %v", err)
}

View File

@ -1,46 +0,0 @@
// +build go1.7
/*
*
* Copyright 2016 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package transport
import (
"context"
"net"
"google.golang.org/grpc/codes"
netctx "golang.org/x/net/context"
)
// dialContext connects to the address on the named network.
func dialContext(ctx context.Context, network, address string) (net.Conn, error) {
return (&net.Dialer{}).DialContext(ctx, network, address)
}
// ContextErr converts the error from context package into a StreamError.
func ContextErr(err error) StreamError {
switch err {
case context.DeadlineExceeded, netctx.DeadlineExceeded:
return streamErrorf(codes.DeadlineExceeded, "%v", err)
case context.Canceled, netctx.Canceled:
return streamErrorf(codes.Canceled, "%v", err)
}
return streamErrorf(codes.Internal, "Unexpected error from context packet: %v", err)
}

View File

@ -173,7 +173,6 @@ func (ht *serverHandlerTransport) do(fn func()) error {
case <-ht.closedCh: case <-ht.closedCh:
return ErrConnClosing return ErrConnClosing
} }
} }
} }
@ -183,6 +182,7 @@ func (ht *serverHandlerTransport) WriteStatus(s *Stream, st *status.Status) erro
ht.mu.Unlock() ht.mu.Unlock()
return nil return nil
} }
ht.streamDone = true
ht.mu.Unlock() ht.mu.Unlock()
err := ht.do(func() { err := ht.do(func() {
ht.writeCommonHeaders(s) ht.writeCommonHeaders(s)
@ -223,9 +223,6 @@ func (ht *serverHandlerTransport) WriteStatus(s *Stream, st *status.Status) erro
} }
}) })
close(ht.writes) close(ht.writes)
ht.mu.Lock()
ht.streamDone = true
ht.mu.Unlock()
return err return err
} }

View File

@ -43,6 +43,7 @@ import (
// http2Client implements the ClientTransport interface with HTTP2. // http2Client implements the ClientTransport interface with HTTP2.
type http2Client struct { type http2Client struct {
ctx context.Context ctx context.Context
cancel context.CancelFunc
target string // server name/addr target string // server name/addr
userAgent string userAgent string
md interface{} md interface{}
@ -52,17 +53,6 @@ type http2Client struct {
authInfo credentials.AuthInfo // auth info about the connection authInfo credentials.AuthInfo // auth info about the connection
nextID uint32 // the next stream ID to be used nextID uint32 // the next stream ID to be used
// writableChan synchronizes write access to the transport.
// A writer acquires the write lock by sending a value on writableChan
// and releases it by receiving from writableChan.
writableChan chan int
// shutdownChan is closed when Close is called.
// Blocking operations should select on shutdownChan to avoid
// blocking forever after Close.
// TODO(zhaoq): Maybe have a channel context?
shutdownChan chan struct{}
// errorChan is closed to notify the I/O error to the caller.
errorChan chan struct{}
// goAway is closed to notify the upper layer (i.e., addrConn.transportMonitor) // goAway is closed to notify the upper layer (i.e., addrConn.transportMonitor)
// that the server sent GoAway on this transport. // that the server sent GoAway on this transport.
goAway chan struct{} goAway chan struct{}
@ -119,7 +109,7 @@ func dial(ctx context.Context, fn func(context.Context, string) (net.Conn, error
if fn != nil { if fn != nil {
return fn(ctx, addr) return fn(ctx, addr)
} }
return dialContext(ctx, "tcp", addr) return (&net.Dialer{}).DialContext(ctx, "tcp", addr)
} }
func isTemporary(err error) bool { func isTemporary(err error) bool {
@ -153,9 +143,18 @@ func isTemporary(err error) bool {
// newHTTP2Client constructs a connected ClientTransport to addr based on HTTP2 // newHTTP2Client constructs a connected ClientTransport to addr based on HTTP2
// and starts to receive messages on it. Non-nil error returns if construction // and starts to receive messages on it. Non-nil error returns if construction
// fails. // fails.
func newHTTP2Client(ctx context.Context, addr TargetInfo, opts ConnectOptions) (_ ClientTransport, err error) { func newHTTP2Client(ctx context.Context, addr TargetInfo, opts ConnectOptions, timeout time.Duration) (_ ClientTransport, err error) {
scheme := "http" scheme := "http"
conn, err := dial(ctx, opts.Dialer, addr.Addr) ctx, cancel := context.WithCancel(ctx)
connectCtx, connectCancel := context.WithTimeout(ctx, timeout)
defer func() {
connectCancel()
if err != nil {
cancel()
}
}()
conn, err := dial(connectCtx, opts.Dialer, addr.Addr)
if err != nil { if err != nil {
if opts.FailOnNonTempDialError { if opts.FailOnNonTempDialError {
return nil, connectionErrorf(isTemporary(err), err, "transport: error while dialing: %v", err) return nil, connectionErrorf(isTemporary(err), err, "transport: error while dialing: %v", err)
@ -174,7 +173,7 @@ func newHTTP2Client(ctx context.Context, addr TargetInfo, opts ConnectOptions) (
) )
if creds := opts.TransportCredentials; creds != nil { if creds := opts.TransportCredentials; creds != nil {
scheme = "https" scheme = "https"
conn, authInfo, err = creds.ClientHandshake(ctx, addr.Addr, conn) conn, authInfo, err = creds.ClientHandshake(connectCtx, addr.Addr, conn)
if err != nil { if err != nil {
// Credentials handshake errors are typically considered permanent // Credentials handshake errors are typically considered permanent
// to avoid retrying on e.g. bad certificates. // to avoid retrying on e.g. bad certificates.
@ -198,8 +197,17 @@ func newHTTP2Client(ctx context.Context, addr TargetInfo, opts ConnectOptions) (
dynamicWindow = false dynamicWindow = false
} }
var buf bytes.Buffer var buf bytes.Buffer
writeBufSize := defaultWriteBufSize
if opts.WriteBufferSize > 0 {
writeBufSize = opts.WriteBufferSize
}
readBufSize := defaultReadBufSize
if opts.ReadBufferSize > 0 {
readBufSize = opts.ReadBufferSize
}
t := &http2Client{ t := &http2Client{
ctx: ctx, ctx: ctx,
cancel: cancel,
target: addr.Addr, target: addr.Addr,
userAgent: opts.UserAgent, userAgent: opts.UserAgent,
md: addr.Metadata, md: addr.Metadata,
@ -209,14 +217,11 @@ func newHTTP2Client(ctx context.Context, addr TargetInfo, opts ConnectOptions) (
authInfo: authInfo, authInfo: authInfo,
// The client initiated stream id is odd starting from 1. // The client initiated stream id is odd starting from 1.
nextID: 1, nextID: 1,
writableChan: make(chan int, 1),
shutdownChan: make(chan struct{}),
errorChan: make(chan struct{}),
goAway: make(chan struct{}), goAway: make(chan struct{}),
awakenKeepalive: make(chan struct{}, 1), awakenKeepalive: make(chan struct{}, 1),
framer: newFramer(conn),
hBuf: &buf, hBuf: &buf,
hEnc: hpack.NewEncoder(&buf), hEnc: hpack.NewEncoder(&buf),
framer: newFramer(conn, writeBufSize, readBufSize),
controlBuf: newControlBuffer(), controlBuf: newControlBuffer(),
fc: &inFlow{limit: uint32(icwz)}, fc: &inFlow{limit: uint32(icwz)},
sendQuotaPool: newQuotaPool(defaultWindowSize), sendQuotaPool: newQuotaPool(defaultWindowSize),
@ -270,12 +275,12 @@ func newHTTP2Client(ctx context.Context, addr TargetInfo, opts ConnectOptions) (
return nil, connectionErrorf(true, err, "transport: preface mismatch, wrote %d bytes; want %d", n, len(clientPreface)) return nil, connectionErrorf(true, err, "transport: preface mismatch, wrote %d bytes; want %d", n, len(clientPreface))
} }
if t.initialWindowSize != defaultWindowSize { if t.initialWindowSize != defaultWindowSize {
err = t.framer.writeSettings(true, http2.Setting{ err = t.framer.fr.WriteSettings(http2.Setting{
ID: http2.SettingInitialWindowSize, ID: http2.SettingInitialWindowSize,
Val: uint32(t.initialWindowSize), Val: uint32(t.initialWindowSize),
}) })
} else { } else {
err = t.framer.writeSettings(true) err = t.framer.fr.WriteSettings()
} }
if err != nil { if err != nil {
t.Close() t.Close()
@ -283,16 +288,19 @@ func newHTTP2Client(ctx context.Context, addr TargetInfo, opts ConnectOptions) (
} }
// Adjust the connection flow control window if needed. // Adjust the connection flow control window if needed.
if delta := uint32(icwz - defaultWindowSize); delta > 0 { if delta := uint32(icwz - defaultWindowSize); delta > 0 {
if err := t.framer.writeWindowUpdate(true, 0, delta); err != nil { if err := t.framer.fr.WriteWindowUpdate(0, delta); err != nil {
t.Close() t.Close()
return nil, connectionErrorf(true, err, "transport: failed to write window update: %v", err) return nil, connectionErrorf(true, err, "transport: failed to write window update: %v", err)
} }
} }
go t.controller() t.framer.writer.Flush()
go func() {
loopyWriter(t.ctx, t.controlBuf, t.itemHandler)
t.Close()
}()
if t.kp.Time != infinity { if t.kp.Time != infinity {
go t.keepalive() go t.keepalive()
} }
t.writableChan <- 0
return t, nil return t, nil
} }
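
The rewritten constructor above splits the context in two: a cancelable per-connection context that lives as long as the transport (cancel is stored and called from Close), and a shorter timeout-bound context that covers only dialing and the credentials handshake. A small illustrative sketch of that pattern, assuming a plain TCP dial stands in for the transport's dialer and handshake:

package main

import (
	"context"
	"fmt"
	"net"
	"time"
)

// connect sketches the pattern: a connection-scoped cancelable context plus a
// connect context whose deadline never outlives this function. The returned
// cancel plays the role of t.cancel and is called by the owner on Close.
func connect(ctx context.Context, addr string, timeout time.Duration) (net.Conn, context.CancelFunc, error) {
	ctx, cancel := context.WithCancel(ctx)
	connectCtx, connectCancel := context.WithTimeout(ctx, timeout)
	defer connectCancel()
	conn, err := (&net.Dialer{}).DialContext(connectCtx, "tcp", addr)
	if err != nil {
		cancel() // tear down the connection-scoped context on failure
		return nil, nil, err
	}
	return conn, cancel, nil
}

func main() {
	// TEST-NET-1 address: the dial is expected to fail or time out quickly.
	if _, _, err := connect(context.Background(), "192.0.2.1:65535", 50*time.Millisecond); err != nil {
		fmt.Println("dial failed as expected:", err)
	}
}
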
@ -307,6 +315,7 @@ func (t *http2Client) newStream(ctx context.Context, callHdr *CallHdr) *Stream {
buf: newRecvBuffer(), buf: newRecvBuffer(),
fc: &inFlow{limit: uint32(t.initialWindowSize)}, fc: &inFlow{limit: uint32(t.initialWindowSize)},
sendQuotaPool: newQuotaPool(int(t.streamSendQuota)), sendQuotaPool: newQuotaPool(int(t.streamSendQuota)),
localSendQuota: newQuotaPool(defaultLocalSendQuota),
headerChan: make(chan struct{}), headerChan: make(chan struct{}),
} }
t.nextID += 2 t.nextID += 2
@ -368,13 +377,13 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
authData[k] = v authData[k] = v
} }
} }
callAuthData := make(map[string]string) callAuthData := map[string]string{}
// Check if credentials.PerRPCCredentials were provided via call options. // Check if credentials.PerRPCCredentials were provided via call options.
// Note: if these credentials are provided both via dial options and call // Note: if these credentials are provided both via dial options and call
// options, then both sets of credentials will be applied. // options, then both sets of credentials will be applied.
if callCreds := callHdr.Creds; callCreds != nil { if callCreds := callHdr.Creds; callCreds != nil {
if !t.isSecure && callCreds.RequireTransportSecurity() { if !t.isSecure && callCreds.RequireTransportSecurity() {
return nil, streamErrorf(codes.Unauthenticated, "transport: cannot send secure credentials on an insecure conneciton") return nil, streamErrorf(codes.Unauthenticated, "transport: cannot send secure credentials on an insecure connection")
} }
data, err := callCreds.GetRequestMetadata(ctx, audience) data, err := callCreds.GetRequestMetadata(ctx, audience)
if err != nil { if err != nil {
@ -400,7 +409,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
return nil, ErrConnClosing return nil, ErrConnClosing
} }
t.mu.Unlock() t.mu.Unlock()
sq, err := wait(ctx, nil, nil, t.shutdownChan, t.streamsQuota.acquire()) sq, err := wait(ctx, t.ctx, nil, nil, t.streamsQuota.acquire())
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -408,19 +417,66 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
if sq > 1 { if sq > 1 {
t.streamsQuota.add(sq - 1) t.streamsQuota.add(sq - 1)
} }
if _, err := wait(ctx, nil, nil, t.shutdownChan, t.writableChan); err != nil { // TODO(mmukhi): Benchmark if the performance gets better if we count the metadata and other header fields
// Return the quota back now because there is no stream returned to the caller. // first and create a slice of that exact size.
if _, ok := err.(StreamError); ok { // Make the slice of certain predictable size to reduce allocations made by append.
t.streamsQuota.add(1) hfLen := 7 // :method, :scheme, :path, :authority, content-type, user-agent, te
hfLen += len(authData) + len(callAuthData)
headerFields := make([]hpack.HeaderField, 0, hfLen)
headerFields = append(headerFields, hpack.HeaderField{Name: ":method", Value: "POST"})
headerFields = append(headerFields, hpack.HeaderField{Name: ":scheme", Value: t.scheme})
headerFields = append(headerFields, hpack.HeaderField{Name: ":path", Value: callHdr.Method})
headerFields = append(headerFields, hpack.HeaderField{Name: ":authority", Value: callHdr.Host})
headerFields = append(headerFields, hpack.HeaderField{Name: "content-type", Value: "application/grpc"})
headerFields = append(headerFields, hpack.HeaderField{Name: "user-agent", Value: t.userAgent})
headerFields = append(headerFields, hpack.HeaderField{Name: "te", Value: "trailers"})
if callHdr.SendCompress != "" {
headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-encoding", Value: callHdr.SendCompress})
}
if dl, ok := ctx.Deadline(); ok {
// Send out timeout regardless its value. The server can detect timeout context by itself.
// TODO(mmukhi): Perhaps this field should be updated when actually writing out to the wire.
timeout := dl.Sub(time.Now())
headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-timeout", Value: encodeTimeout(timeout)})
}
for k, v := range authData {
headerFields = append(headerFields, hpack.HeaderField{Name: k, Value: encodeMetadataHeader(k, v)})
}
for k, v := range callAuthData {
headerFields = append(headerFields, hpack.HeaderField{Name: k, Value: encodeMetadataHeader(k, v)})
}
if b := stats.OutgoingTags(ctx); b != nil {
headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-tags-bin", Value: encodeBinHeader(b)})
}
if b := stats.OutgoingTrace(ctx); b != nil {
headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-trace-bin", Value: encodeBinHeader(b)})
}
if md, ok := metadata.FromOutgoingContext(ctx); ok {
for k, vv := range md {
// HTTP doesn't allow you to set pseudoheaders after non pseudoheaders were set.
if isReservedHeader(k) {
continue
}
for _, v := range vv {
headerFields = append(headerFields, hpack.HeaderField{Name: k, Value: encodeMetadataHeader(k, v)})
}
}
}
if md, ok := t.md.(*metadata.MD); ok {
for k, vv := range *md {
if isReservedHeader(k) {
continue
}
for _, v := range vv {
headerFields = append(headerFields, hpack.HeaderField{Name: k, Value: encodeMetadataHeader(k, v)})
}
} }
return nil, err
} }
t.mu.Lock() t.mu.Lock()
if t.state == draining { if t.state == draining {
t.mu.Unlock() t.mu.Unlock()
t.streamsQuota.add(1) t.streamsQuota.add(1)
// Need to make t writable again so that the rpc in flight can still proceed.
t.writableChan <- 0
return nil, ErrStreamDrain return nil, ErrStreamDrain
} }
if t.state != reachable { if t.state != reachable {
@ -434,7 +490,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
if len(t.activeStreams) == 1 { if len(t.activeStreams) == 1 {
select { select {
case t.awakenKeepalive <- struct{}{}: case t.awakenKeepalive <- struct{}{}:
t.framer.writePing(false, false, [8]byte{}) t.controlBuf.put(&ping{data: [8]byte{}})
// Fill the awakenKeepalive channel again as this channel must be // Fill the awakenKeepalive channel again as this channel must be
// kept non-writable except at the point that the keepalive() // kept non-writable except at the point that the keepalive()
// goroutine is waiting either to be awaken or shutdown. // goroutine is waiting either to be awaken or shutdown.
@ -442,102 +498,13 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
default: default:
} }
} }
t.controlBuf.put(&headerFrame{
streamID: s.id,
hf: headerFields,
endStream: false,
})
t.mu.Unlock() t.mu.Unlock()
// HPACK encodes various headers. Note that once WriteField(...) is
// called, the corresponding headers/continuation frame has to be sent
// because hpack.Encoder is stateful.
t.hBuf.Reset()
t.hEnc.WriteField(hpack.HeaderField{Name: ":method", Value: "POST"})
t.hEnc.WriteField(hpack.HeaderField{Name: ":scheme", Value: t.scheme})
t.hEnc.WriteField(hpack.HeaderField{Name: ":path", Value: callHdr.Method})
t.hEnc.WriteField(hpack.HeaderField{Name: ":authority", Value: callHdr.Host})
t.hEnc.WriteField(hpack.HeaderField{Name: "content-type", Value: "application/grpc"})
t.hEnc.WriteField(hpack.HeaderField{Name: "user-agent", Value: t.userAgent})
t.hEnc.WriteField(hpack.HeaderField{Name: "te", Value: "trailers"})
if callHdr.SendCompress != "" {
t.hEnc.WriteField(hpack.HeaderField{Name: "grpc-encoding", Value: callHdr.SendCompress})
}
if dl, ok := ctx.Deadline(); ok {
// Send out timeout regardless its value. The server can detect timeout context by itself.
timeout := dl.Sub(time.Now())
t.hEnc.WriteField(hpack.HeaderField{Name: "grpc-timeout", Value: encodeTimeout(timeout)})
}
for k, v := range authData {
t.hEnc.WriteField(hpack.HeaderField{Name: k, Value: encodeMetadataHeader(k, v)})
}
for k, v := range callAuthData {
t.hEnc.WriteField(hpack.HeaderField{Name: k, Value: encodeMetadataHeader(k, v)})
}
var (
endHeaders bool
)
if b := stats.OutgoingTags(ctx); b != nil {
t.hEnc.WriteField(hpack.HeaderField{Name: "grpc-tags-bin", Value: encodeBinHeader(b)})
}
if b := stats.OutgoingTrace(ctx); b != nil {
t.hEnc.WriteField(hpack.HeaderField{Name: "grpc-trace-bin", Value: encodeBinHeader(b)})
}
if md, ok := metadata.FromOutgoingContext(ctx); ok {
for k, vv := range md {
// HTTP doesn't allow you to set pseudoheaders after non pseudoheaders were set.
if isReservedHeader(k) {
continue
}
for _, v := range vv {
t.hEnc.WriteField(hpack.HeaderField{Name: k, Value: encodeMetadataHeader(k, v)})
}
}
}
if md, ok := t.md.(*metadata.MD); ok {
for k, vv := range *md {
if isReservedHeader(k) {
continue
}
for _, v := range vv {
t.hEnc.WriteField(hpack.HeaderField{Name: k, Value: encodeMetadataHeader(k, v)})
}
}
}
first := true
bufLen := t.hBuf.Len()
// Sends the headers in a single batch even when they span multiple frames.
for !endHeaders {
size := t.hBuf.Len()
if size > http2MaxFrameLen {
size = http2MaxFrameLen
} else {
endHeaders = true
}
var flush bool
if callHdr.Flush && endHeaders {
flush = true
}
if first {
// Sends a HeadersFrame to server to start a new stream.
p := http2.HeadersFrameParam{
StreamID: s.id,
BlockFragment: t.hBuf.Next(size),
EndStream: false,
EndHeaders: endHeaders,
}
// Do a force flush for the buffered frames iff it is the last headers frame
// and there is header metadata to be sent. Otherwise, there is flushing until
// the corresponding data frame is written.
err = t.framer.writeHeaders(flush, p)
first = false
} else {
// Sends Continuation frames for the leftover headers.
err = t.framer.writeContinuation(flush, s.id, endHeaders, t.hBuf.Next(size))
}
if err != nil {
t.notifyError(err)
return nil, connectionErrorf(true, err, "transport: %v", err)
}
}
s.mu.Lock() s.mu.Lock()
s.bytesSent = true s.bytesSent = true
s.mu.Unlock() s.mu.Unlock()
@ -545,7 +512,6 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
if t.statsHandler != nil { if t.statsHandler != nil {
outHeader := &stats.OutHeader{ outHeader := &stats.OutHeader{
Client: true, Client: true,
WireLength: bufLen,
FullMethod: callHdr.Method, FullMethod: callHdr.Method,
RemoteAddr: t.remoteAddr, RemoteAddr: t.remoteAddr,
LocalAddr: t.localAddr, LocalAddr: t.localAddr,
@ -553,7 +519,6 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
} }
t.statsHandler.HandleRPC(s.ctx, outHeader) t.statsHandler.HandleRPC(s.ctx, outHeader)
} }
t.writableChan <- 0
return s, nil return s, nil
} }
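
NewStream no longer HPACK-encodes headers inline under the write lock; it builds a []hpack.HeaderField (preallocated from the known field count plus per-RPC metadata) and queues it on the control buffer as a headerFrame for the writer goroutine to encode. A minimal sketch of that slice-building step using the real golang.org/x/net/http2/hpack API; the field set and helper name are illustrative:

package main

import (
	"bytes"
	"fmt"

	"golang.org/x/net/http2/hpack"
)

// buildHeaderFields mirrors the approach above: size the slice up front from
// the known pseudo-header count plus per-RPC metadata, then append in order.
func buildHeaderFields(method, authority string, md map[string]string) []hpack.HeaderField {
	hfLen := 5 + len(md) // :method, :scheme, :path, :authority, content-type, plus metadata
	fields := make([]hpack.HeaderField, 0, hfLen)
	fields = append(fields,
		hpack.HeaderField{Name: ":method", Value: "POST"},
		hpack.HeaderField{Name: ":scheme", Value: "http"},
		hpack.HeaderField{Name: ":path", Value: method},
		hpack.HeaderField{Name: ":authority", Value: authority},
		hpack.HeaderField{Name: "content-type", Value: "application/grpc"},
	)
	for k, v := range md {
		fields = append(fields, hpack.HeaderField{Name: k, Value: v})
	}
	return fields
}

func main() {
	var buf bytes.Buffer
	enc := hpack.NewEncoder(&buf)
	for _, f := range buildHeaderFields("/pkg.Service/GetFoo", "example.com", map[string]string{"x-demo": "1"}) {
		// In the transport this encoding happens later, inside the single
		// writer goroutine that owns the stateful HPACK encoder.
		if err := enc.WriteField(f); err != nil {
			panic(err)
		}
	}
	fmt.Printf("encoded %d header bytes\n", buf.Len())
}
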
@ -623,12 +588,9 @@ func (t *http2Client) Close() (err error) {
t.mu.Unlock() t.mu.Unlock()
return return
} }
if t.state == reachable || t.state == draining {
close(t.errorChan)
}
t.state = closing t.state = closing
t.mu.Unlock() t.mu.Unlock()
close(t.shutdownChan) t.cancel()
err = t.conn.Close() err = t.conn.Close()
t.mu.Lock() t.mu.Lock()
streams := t.activeStreams streams := t.activeStreams
@ -650,23 +612,18 @@ func (t *http2Client) Close() (err error) {
} }
t.statsHandler.HandleConn(t.ctx, connEnd) t.statsHandler.HandleConn(t.ctx, connEnd)
} }
return return err
} }
// GracefulClose sets the state to draining, which prevents new streams from
// being created and causes the transport to be closed when the last active
// stream is closed. If there are no active streams, the transport is closed
// immediately. This does nothing if the transport is already draining or
// closing.
func (t *http2Client) GracefulClose() error { func (t *http2Client) GracefulClose() error {
t.mu.Lock() t.mu.Lock()
switch t.state { switch t.state {
case unreachable: case closing, draining:
// The server may close the connection concurrently. t is not available for
// any streams. Close it now.
t.mu.Unlock()
t.Close()
return nil
case closing:
t.mu.Unlock()
return nil
}
if t.state == draining {
t.mu.Unlock() t.mu.Unlock()
return nil return nil
} }
@ -681,32 +638,38 @@ func (t *http2Client) GracefulClose() error {
// Write formats the data into HTTP2 data frame(s) and sends it out. The caller // Write formats the data into HTTP2 data frame(s) and sends it out. The caller
// should proceed only if Write returns nil. // should proceed only if Write returns nil.
// TODO(zhaoq): opts.Delay is ignored in this implementation. Support it later
// if it improves the performance.
func (t *http2Client) Write(s *Stream, hdr []byte, data []byte, opts *Options) error { func (t *http2Client) Write(s *Stream, hdr []byte, data []byte, opts *Options) error {
secondStart := http2MaxFrameLen - len(hdr)%http2MaxFrameLen select {
if len(data) < secondStart { case <-s.ctx.Done():
secondStart = len(data) return ContextErr(s.ctx.Err())
case <-t.ctx.Done():
return ErrConnClosing
default:
} }
hdr = append(hdr, data[:secondStart]...)
data = data[secondStart:] if hdr == nil && data == nil && opts.Last {
isLastSlice := (len(data) == 0) // stream.CloseSend uses this to send an empty frame with endStream=True
r := bytes.NewBuffer(hdr) t.controlBuf.put(&dataFrame{streamID: s.id, endStream: true, f: func() {}})
var ( return nil
p []byte }
oqv uint32 // Add data to header frame so that we can equally distribute data across frames.
) emptyLen := http2MaxFrameLen - len(hdr)
for { if emptyLen > len(data) {
oqv = atomic.LoadUint32(&t.outQuotaVersion) emptyLen = len(data)
if r.Len() > 0 || p != nil { }
hdr = append(hdr, data[:emptyLen]...)
data = data[emptyLen:]
for idx, r := range [][]byte{hdr, data} {
for len(r) > 0 {
size := http2MaxFrameLen size := http2MaxFrameLen
// Wait until the stream has some quota to send the data. // Wait until the stream has some quota to send the data.
sq, err := wait(s.ctx, s.done, s.goAway, t.shutdownChan, s.sendQuotaPool.acquire()) quotaChan, quotaVer := s.sendQuotaPool.acquireWithVersion()
sq, err := wait(s.ctx, t.ctx, s.done, s.goAway, quotaChan)
if err != nil { if err != nil {
return err return err
} }
// Wait until the transport has some quota to send the data. // Wait until the transport has some quota to send the data.
tq, err := wait(s.ctx, s.done, s.goAway, t.shutdownChan, t.sendQuotaPool.acquire()) tq, err := wait(s.ctx, t.ctx, s.done, s.goAway, t.sendQuotaPool.acquire())
if err != nil { if err != nil {
return err return err
} }
@ -716,93 +679,51 @@ func (t *http2Client) Write(s *Stream, hdr []byte, data []byte, opts *Options) e
if tq < size { if tq < size {
size = tq size = tq
} }
if p == nil { if size > len(r) {
p = r.Next(size) size = len(r)
} }
p := r[:size]
ps := len(p) ps := len(p)
if ps < sq {
// Overbooked stream quota. Return it back.
s.sendQuotaPool.add(sq - ps)
}
if ps < tq { if ps < tq {
// Overbooked transport quota. Return it back. // Overbooked transport quota. Return it back.
t.sendQuotaPool.add(tq - ps) t.sendQuotaPool.add(tq - ps)
} }
} // Acquire local send quota to be able to write to the controlBuf.
var ( ltq, err := wait(s.ctx, t.ctx, s.done, s.goAway, s.localSendQuota.acquire())
endStream bool if err != nil {
forceFlush bool if _, ok := err.(ConnectionError); !ok {
) t.sendQuotaPool.add(ps)
// Indicate there is a writer who is about to write a data frame.
t.framer.adjustNumWriters(1)
// Got some quota. Try to acquire writing privilege on the transport.
if _, err := wait(s.ctx, s.done, s.goAway, t.shutdownChan, t.writableChan); err != nil {
if _, ok := err.(StreamError); ok || err == io.EOF {
// Return the connection quota back.
t.sendQuotaPool.add(len(p))
}
if t.framer.adjustNumWriters(-1) == 0 {
// This writer is the last one in this batch and has the
// responsibility to flush the buffered frames. It queues
// a flush request to controlBuf instead of flushing directly
// in order to avoid the race with other writing or flushing.
t.controlBuf.put(&flushIO{})
} }
return err return err
} }
select { s.localSendQuota.add(ltq - ps) // It's ok if we make it negative.
case <-s.ctx.Done(): var endStream bool
t.sendQuotaPool.add(len(p)) // See if this is the last frame to be written.
if t.framer.adjustNumWriters(-1) == 0 {
t.controlBuf.put(&flushIO{})
}
t.writableChan <- 0
return ContextErr(s.ctx.Err())
default:
}
if oqv != atomic.LoadUint32(&t.outQuotaVersion) {
// InitialWindowSize settings frame must have been received after we
// acquired send quota but before we got the writable channel.
// We must forsake this write.
t.sendQuotaPool.add(len(p))
s.sendQuotaPool.add(len(p))
if t.framer.adjustNumWriters(-1) == 0 {
t.controlBuf.put(&flushIO{})
}
t.writableChan <- 0
continue
}
if r.Len() == 0 {
if isLastSlice {
if opts.Last { if opts.Last {
if len(r)-size == 0 { // No more data in r after this iteration.
if idx == 0 { // We're writing data header.
if len(data) == 0 { // There's no data to follow.
endStream = true endStream = true
} }
if t.framer.adjustNumWriters(0) == 1 { } else { // We're writing data.
// Do a force flush iff this is last frame for the entire gRPC message endStream = true
// and the caller is the only writer at this moment.
forceFlush = true
}
} else {
isLastSlice = true
if len(data) != 0 {
r = bytes.NewBuffer(data)
} }
} }
} }
// If WriteData fails, all the pending streams will be handled success := func() {
// by http2Client.Close(). No explicit CloseStream() needs to be t.controlBuf.put(&dataFrame{streamID: s.id, endStream: endStream, d: p, f: func() { s.localSendQuota.add(ps) }})
// invoked. if ps < sq {
if err := t.framer.writeData(forceFlush, s.id, endStream, p); err != nil { s.sendQuotaPool.lockedAdd(sq - ps)
t.notifyError(err)
return connectionErrorf(true, err, "transport: %v", err)
} }
p = nil r = r[ps:]
if t.framer.adjustNumWriters(-1) == 0 { }
t.framer.flushWrite() failure := func() {
s.sendQuotaPool.lockedAdd(sq)
}
if !s.sendQuotaPool.compareAndExecute(quotaVer, success, failure) {
t.sendQuotaPool.add(ps)
s.localSendQuota.add(ps)
} }
t.writableChan <- 0
if r.Len() == 0 {
break
} }
} }
if !opts.Last { if !opts.Last {
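
The rewritten Write first tops up the gRPC message header with the start of the payload so the first HTTP/2 frame is full, then walks [][]byte{hdr, data} emitting frames of at most http2MaxFrameLen, with each chunk additionally gated by stream, transport, and local send quota. A minimal sketch of just the splitting arithmetic, with the quota handling omitted and the frame size assumed:

package main

import "fmt"

const maxFrameLen = 16384 // stands in for http2MaxFrameLen

// splitFrames mirrors the slicing above: move the start of the payload into
// the header slice so the first frame is full, then emit <=maxFrameLen chunks
// from whatever remains of the two slices.
func splitFrames(hdr, data []byte) [][]byte {
	fill := maxFrameLen - len(hdr)
	if fill > len(data) {
		fill = len(data)
	}
	hdr = append(hdr, data[:fill]...)
	data = data[fill:]

	var frames [][]byte
	for _, r := range [][]byte{hdr, data} {
		for len(r) > 0 {
			size := maxFrameLen
			if size > len(r) {
				size = len(r)
			}
			frames = append(frames, r[:size])
			r = r[size:]
		}
	}
	return frames
}

func main() {
	hdr := make([]byte, 5)      // gRPC message header
	data := make([]byte, 40000) // payload spanning multiple frames
	for i, f := range splitFrames(hdr, data) {
		fmt.Printf("frame %d: %d bytes\n", i, len(f))
	}
}

In the real code each emitted chunk becomes a dataFrame on the control buffer, and the f callback returns local send quota once the frame has actually been written.
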
@ -833,11 +754,11 @@ func (t *http2Client) adjustWindow(s *Stream, n uint32) {
return return
} }
if w := s.fc.maybeAdjust(n); w > 0 { if w := s.fc.maybeAdjust(n); w > 0 {
// Piggyback conneciton's window update along. // Piggyback connection's window update along.
if cw := t.fc.resetPendingUpdate(); cw > 0 { if cw := t.fc.resetPendingUpdate(); cw > 0 {
t.controlBuf.put(&windowUpdate{0, cw, false}) t.controlBuf.put(&windowUpdate{0, cw})
} }
t.controlBuf.put(&windowUpdate{s.id, w, true}) t.controlBuf.put(&windowUpdate{s.id, w})
} }
} }
@ -852,9 +773,9 @@ func (t *http2Client) updateWindow(s *Stream, n uint32) {
} }
if w := s.fc.onRead(n); w > 0 { if w := s.fc.onRead(n); w > 0 {
if cw := t.fc.resetPendingUpdate(); cw > 0 { if cw := t.fc.resetPendingUpdate(); cw > 0 {
t.controlBuf.put(&windowUpdate{0, cw, false}) t.controlBuf.put(&windowUpdate{0, cw})
} }
t.controlBuf.put(&windowUpdate{s.id, w, true}) t.controlBuf.put(&windowUpdate{s.id, w})
} }
} }
@ -868,7 +789,7 @@ func (t *http2Client) updateFlowControl(n uint32) {
} }
t.initialWindowSize = int32(n) t.initialWindowSize = int32(n)
t.mu.Unlock() t.mu.Unlock()
t.controlBuf.put(&windowUpdate{0, t.fc.newLimit(n), false}) t.controlBuf.put(&windowUpdate{0, t.fc.newLimit(n)})
t.controlBuf.put(&settings{ t.controlBuf.put(&settings{
ack: false, ack: false,
ss: []http2.Setting{ ss: []http2.Setting{
@ -898,15 +819,17 @@ func (t *http2Client) handleData(f *http2.DataFrame) {
// Furthermore, if a bdpPing is being sent out we can piggyback // Furthermore, if a bdpPing is being sent out we can piggyback
// connection's window update for the bytes we just received. // connection's window update for the bytes we just received.
if sendBDPPing { if sendBDPPing {
t.controlBuf.put(&windowUpdate{0, uint32(size), false}) if size != 0 { // Could've been an empty data frame.
t.controlBuf.put(&windowUpdate{0, uint32(size)})
}
t.controlBuf.put(bdpPing) t.controlBuf.put(bdpPing)
} else { } else {
if err := t.fc.onData(uint32(size)); err != nil { if err := t.fc.onData(uint32(size)); err != nil {
t.notifyError(connectionErrorf(true, err, "%v", err)) t.Close()
return return
} }
if w := t.fc.onRead(uint32(size)); w > 0 { if w := t.fc.onRead(uint32(size)); w > 0 {
t.controlBuf.put(&windowUpdate{0, w, true}) t.controlBuf.put(&windowUpdate{0, w})
} }
} }
// Select the right stream to dispatch. // Select the right stream to dispatch.
@ -930,7 +853,7 @@ func (t *http2Client) handleData(f *http2.DataFrame) {
} }
if f.Header().Flags.Has(http2.FlagDataPadded) { if f.Header().Flags.Has(http2.FlagDataPadded) {
if w := s.fc.onRead(uint32(size) - uint32(len(f.Data()))); w > 0 { if w := s.fc.onRead(uint32(size) - uint32(len(f.Data()))); w > 0 {
t.controlBuf.put(&windowUpdate{s.id, w, true}) t.controlBuf.put(&windowUpdate{s.id, w})
} }
} }
s.mu.Unlock() s.mu.Unlock()
@ -1019,10 +942,10 @@ func (t *http2Client) handleGoAway(f *http2.GoAwayFrame) {
id := f.LastStreamID id := f.LastStreamID
if id > 0 && id%2 != 1 { if id > 0 && id%2 != 1 {
t.mu.Unlock() t.mu.Unlock()
t.notifyError(connectionErrorf(true, nil, "received illegal http2 GOAWAY frame: stream ID %d is even", f.LastStreamID)) t.Close()
return return
} }
// A client can recieve multiple GoAways from server (look at https://github.com/grpc/grpc-go/issues/1387). // A client can receive multiple GoAways from server (look at https://github.com/grpc/grpc-go/issues/1387).
// The idea is that the first GoAway will be sent with an ID of MaxInt32 and the second GoAway will be sent after an RTT delay // The idea is that the first GoAway will be sent with an ID of MaxInt32 and the second GoAway will be sent after an RTT delay
// with the ID of the last stream the server will process. // with the ID of the last stream the server will process.
// Therefore, when we get the first GoAway we don't really close any streams. While in case of second GoAway we // Therefore, when we get the first GoAway we don't really close any streams. While in case of second GoAway we
@ -1033,7 +956,7 @@ func (t *http2Client) handleGoAway(f *http2.GoAwayFrame) {
// If there are multiple GoAways the first one should always have an ID greater than the following ones. // If there are multiple GoAways the first one should always have an ID greater than the following ones.
if id > t.prevGoAwayID { if id > t.prevGoAwayID {
t.mu.Unlock() t.mu.Unlock()
t.notifyError(connectionErrorf(true, nil, "received illegal http2 GOAWAY frame: previously recv GOAWAY frame with LastStramID %d, currently recv %d", id, f.LastStreamID)) t.Close()
return return
} }
default: default:
@ -1177,22 +1100,22 @@ func handleMalformedHTTP2(s *Stream, err error) {
// TODO(zhaoq): Check the validity of the incoming frame sequence. // TODO(zhaoq): Check the validity of the incoming frame sequence.
func (t *http2Client) reader() { func (t *http2Client) reader() {
// Check the validity of server preface. // Check the validity of server preface.
frame, err := t.framer.readFrame() frame, err := t.framer.fr.ReadFrame()
if err != nil { if err != nil {
t.notifyError(err) t.Close()
return return
} }
atomic.CompareAndSwapUint32(&t.activity, 0, 1) atomic.CompareAndSwapUint32(&t.activity, 0, 1)
sf, ok := frame.(*http2.SettingsFrame) sf, ok := frame.(*http2.SettingsFrame)
if !ok { if !ok {
t.notifyError(err) t.Close()
return return
} }
t.handleSettings(sf) t.handleSettings(sf)
// loop to keep reading incoming messages on this transport. // loop to keep reading incoming messages on this transport.
for { for {
frame, err := t.framer.readFrame() frame, err := t.framer.fr.ReadFrame()
atomic.CompareAndSwapUint32(&t.activity, 0, 1) atomic.CompareAndSwapUint32(&t.activity, 0, 1)
if err != nil { if err != nil {
// Abort an active stream if the http2.Framer returns a // Abort an active stream if the http2.Framer returns a
@ -1204,12 +1127,12 @@ func (t *http2Client) reader() {
t.mu.Unlock() t.mu.Unlock()
if s != nil { if s != nil {
// use error detail to provide better err message // use error detail to provide better err message
handleMalformedHTTP2(s, streamErrorf(http2ErrConvTab[se.Code], "%v", t.framer.errorDetail())) handleMalformedHTTP2(s, streamErrorf(http2ErrConvTab[se.Code], "%v", t.framer.fr.ErrorDetail()))
} }
continue continue
} else { } else {
// Transport error. // Transport error.
t.notifyError(err) t.Close()
return return
} }
} }
@ -1253,33 +1176,66 @@ func (t *http2Client) applySettings(ss []http2.Setting) {
t.mu.Lock() t.mu.Lock()
for _, stream := range t.activeStreams { for _, stream := range t.activeStreams {
// Adjust the sending quota for each stream. // Adjust the sending quota for each stream.
stream.sendQuotaPool.add(int(s.Val) - int(t.streamSendQuota)) stream.sendQuotaPool.addAndUpdate(int(s.Val) - int(t.streamSendQuota))
} }
t.streamSendQuota = s.Val t.streamSendQuota = s.Val
t.mu.Unlock() t.mu.Unlock()
atomic.AddUint32(&t.outQuotaVersion, 1)
} }
} }
} }
// controller running in a separate goroutine takes charge of sending control // TODO(mmukhi): A lot of this code (and code in other places in the transport layer)
// frames (e.g., window update, reset stream, setting, etc.) to the server. // is duplicated between the client and the server.
func (t *http2Client) controller() { // The transport layer needs to be refactored to take care of this.
for { func (t *http2Client) itemHandler(i item) error {
select { var err error
case i := <-t.controlBuf.get():
t.controlBuf.load()
select {
case <-t.writableChan:
switch i := i.(type) { switch i := i.(type) {
case *dataFrame:
err = t.framer.fr.WriteData(i.streamID, i.endStream, i.d)
if err == nil {
i.f()
}
case *headerFrame:
t.hBuf.Reset()
for _, f := range i.hf {
t.hEnc.WriteField(f)
}
endHeaders := false
first := true
for !endHeaders {
size := t.hBuf.Len()
if size > http2MaxFrameLen {
size = http2MaxFrameLen
} else {
endHeaders = true
}
if first {
first = false
err = t.framer.fr.WriteHeaders(http2.HeadersFrameParam{
StreamID: i.streamID,
BlockFragment: t.hBuf.Next(size),
EndStream: i.endStream,
EndHeaders: endHeaders,
})
} else {
err = t.framer.fr.WriteContinuation(
i.streamID,
endHeaders,
t.hBuf.Next(size),
)
}
if err != nil {
return err
}
}
case *windowUpdate: case *windowUpdate:
t.framer.writeWindowUpdate(i.flush, i.streamID, i.increment) err = t.framer.fr.WriteWindowUpdate(i.streamID, i.increment)
case *settings: case *settings:
if i.ack { if i.ack {
t.framer.writeSettingsAck(true)
t.applySettings(i.ss) t.applySettings(i.ss)
err = t.framer.fr.WriteSettingsAck()
} else { } else {
t.framer.writeSettings(true, i.ss...) err = t.framer.fr.WriteSettings(i.ss...)
} }
case *resetStream: case *resetStream:
// If the server needs to be to intimated about stream closing, // If the server needs to be to intimated about stream closing,
@ -1287,27 +1243,19 @@ func (t *http2Client) controller() {
// the wire before the headers of the next stream waiting on // the wire before the headers of the next stream waiting on
// streamQuota. We ensure this by adding to the streamsQuota pool // streamQuota. We ensure this by adding to the streamsQuota pool
// only after having acquired the writableChan to send RST_STREAM. // only after having acquired the writableChan to send RST_STREAM.
err = t.framer.fr.WriteRSTStream(i.streamID, i.code)
t.streamsQuota.add(1) t.streamsQuota.add(1)
t.framer.writeRSTStream(true, i.streamID, i.code)
case *flushIO: case *flushIO:
t.framer.flushWrite() err = t.framer.writer.Flush()
case *ping: case *ping:
if !i.ack { if !i.ack {
t.bdpEst.timesnap(i.data) t.bdpEst.timesnap(i.data)
} }
t.framer.writePing(true, i.ack, i.data) err = t.framer.fr.WritePing(i.ack, i.data)
default: default:
errorf("transport: http2Client.controller got unexpected item type %v\n", i) errorf("transport: http2Client.controller got unexpected item type %v\n", i)
} }
t.writableChan <- 0 return err
continue
case <-t.shutdownChan:
return
}
case <-t.shutdownChan:
return
}
}
} }
// keepalive running in a separate goroutune makes sure the connection is alive by sending pings. // keepalive running in a separate goroutune makes sure the connection is alive by sending pings.
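
The controller goroutine and its writableChan hand-off are replaced by loopyWriter feeding a per-item handler: every control item (data, headers, window updates, settings, resets, pings) is serialized through the control buffer and written by a single goroutine, which is what lets NewStream and Write stop acquiring the write lock themselves. A compact sketch of that dispatch shape, with stand-in item types rather than the transport's:

package main

import (
	"context"
	"fmt"
)

// item is anything the single writer goroutine knows how to flush to the wire.
type item interface{ item() }

type windowUpdate struct{ streamID, increment uint32 }
type ping struct{ ack bool }

func (windowUpdate) item() {}
func (ping) item()         {}

// loopy drains the control channel and hands each item to handler, stopping on
// handler error or when the transport context is canceled -- the same shape as
// loopyWriter calling itemHandler above.
func loopy(ctx context.Context, ch <-chan item, handler func(item) error) {
	for {
		select {
		case i := <-ch:
			if err := handler(i); err != nil {
				return
			}
		case <-ctx.Done():
			return
		}
	}
}

func main() {
	ch := make(chan item, 2)
	ch <- windowUpdate{streamID: 1, increment: 1024}
	ch <- ping{}
	ctx, cancel := context.WithCancel(context.Background())
	handled := 0
	loopy(ctx, ch, func(i item) error {
		switch i := i.(type) {
		case windowUpdate:
			fmt.Println("WINDOW_UPDATE", i.streamID, i.increment)
		case ping:
			fmt.Println("PING ack =", i.ack)
		}
		handled++
		if handled == 2 {
			cancel() // stop the loop once the demo items are flushed
		}
		return nil
	})
}
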
@ -1331,7 +1279,7 @@ func (t *http2Client) keepalive() {
case <-t.awakenKeepalive: case <-t.awakenKeepalive:
// If the control gets here a ping has been sent // If the control gets here a ping has been sent
// need to reset the timer with keepalive.Timeout. // need to reset the timer with keepalive.Timeout.
case <-t.shutdownChan: case <-t.ctx.Done():
return return
} }
} else { } else {
@ -1350,13 +1298,13 @@ func (t *http2Client) keepalive() {
} }
t.Close() t.Close()
return return
case <-t.shutdownChan: case <-t.ctx.Done():
if !timer.Stop() { if !timer.Stop() {
<-timer.C <-timer.C
} }
return return
} }
case <-t.shutdownChan: case <-t.ctx.Done():
if !timer.Stop() { if !timer.Stop() {
<-timer.C <-timer.C
} }
@ -1366,25 +1314,9 @@ func (t *http2Client) keepalive() {
} }
func (t *http2Client) Error() <-chan struct{} { func (t *http2Client) Error() <-chan struct{} {
return t.errorChan return t.ctx.Done()
} }
func (t *http2Client) GoAway() <-chan struct{} { func (t *http2Client) GoAway() <-chan struct{} {
return t.goAway return t.goAway
} }
func (t *http2Client) notifyError(err error) {
t.mu.Lock()
// make sure t.errorChan is closed only once.
if t.state == draining {
t.mu.Unlock()
t.Close()
return
}
if t.state == reachable {
t.state = unreachable
close(t.errorChan)
infof("transport: http2Client.notifyError got notified that the client transport was broken %v.", err)
}
t.mu.Unlock()
}

View File

@ -21,6 +21,7 @@ package transport
import ( import (
"bytes" "bytes"
"errors" "errors"
"fmt"
"io" "io"
"math" "math"
"math/rand" "math/rand"
@ -51,20 +52,13 @@ var ErrIllegalHeaderWrite = errors.New("transport: the stream is done or WriteHe
// http2Server implements the ServerTransport interface with HTTP2. // http2Server implements the ServerTransport interface with HTTP2.
type http2Server struct { type http2Server struct {
ctx context.Context ctx context.Context
cancel context.CancelFunc
conn net.Conn conn net.Conn
remoteAddr net.Addr remoteAddr net.Addr
localAddr net.Addr localAddr net.Addr
maxStreamID uint32 // max stream ID ever seen maxStreamID uint32 // max stream ID ever seen
authInfo credentials.AuthInfo // auth info about the connection authInfo credentials.AuthInfo // auth info about the connection
inTapHandle tap.ServerInHandle inTapHandle tap.ServerInHandle
// writableChan synchronizes write access to the transport.
// A writer acquires the write lock by receiving a value on writableChan
// and releases it by sending on writableChan.
writableChan chan int
// shutdownChan is closed when Close is called.
// Blocking operations should select on shutdownChan to avoid
// blocking forever after Close.
shutdownChan chan struct{}
framer *framer framer *framer
hBuf *bytes.Buffer // the buffer for HPACK encoding hBuf *bytes.Buffer // the buffer for HPACK encoding
hEnc *hpack.Encoder // HPACK encoder hEnc *hpack.Encoder // HPACK encoder
@ -96,8 +90,6 @@ type http2Server struct {
initialWindowSize int32 initialWindowSize int32
bdpEst *bdpEstimator bdpEst *bdpEstimator
outQuotaVersion uint32
mu sync.Mutex // guard the following mu sync.Mutex // guard the following
// drainChan is initialized when drain(...) is called the first time. // drainChan is initialized when drain(...) is called the first time.
@ -112,7 +104,7 @@ type http2Server struct {
// the per-stream outbound flow control window size set by the peer. // the per-stream outbound flow control window size set by the peer.
streamSendQuota uint32 streamSendQuota uint32
// idle is the time instant when the connection went idle. // idle is the time instant when the connection went idle.
// This is either the begining of the connection or when the number of // This is either the beginning of the connection or when the number of
// RPCs go down to 0. // RPCs go down to 0.
// When the connection is busy, this value is set to 0. // When the connection is busy, this value is set to 0.
idle time.Time idle time.Time
@ -121,7 +113,15 @@ type http2Server struct {
// newHTTP2Server constructs a ServerTransport based on HTTP2. ConnectionError is // newHTTP2Server constructs a ServerTransport based on HTTP2. ConnectionError is
// returned if something goes wrong. // returned if something goes wrong.
func newHTTP2Server(conn net.Conn, config *ServerConfig) (_ ServerTransport, err error) { func newHTTP2Server(conn net.Conn, config *ServerConfig) (_ ServerTransport, err error) {
framer := newFramer(conn) writeBufSize := defaultWriteBufSize
if config.WriteBufferSize > 0 {
writeBufSize = config.WriteBufferSize
}
readBufSize := defaultReadBufSize
if config.ReadBufferSize > 0 {
readBufSize = config.ReadBufferSize
}
framer := newFramer(conn, writeBufSize, readBufSize)
// Send initial settings as connection preface to client. // Send initial settings as connection preface to client.
var isettings []http2.Setting var isettings []http2.Setting
// TODO(zhaoq): Have a better way to signal "no limit" because 0 is // TODO(zhaoq): Have a better way to signal "no limit" because 0 is
@ -151,12 +151,12 @@ func newHTTP2Server(conn net.Conn, config *ServerConfig) (_ ServerTransport, err
ID: http2.SettingInitialWindowSize, ID: http2.SettingInitialWindowSize,
Val: uint32(iwz)}) Val: uint32(iwz)})
} }
if err := framer.writeSettings(true, isettings...); err != nil { if err := framer.fr.WriteSettings(isettings...); err != nil {
return nil, connectionErrorf(true, err, "transport: %v", err) return nil, connectionErrorf(true, err, "transport: %v", err)
} }
// Adjust the connection flow control window if needed. // Adjust the connection flow control window if needed.
if delta := uint32(icwz - defaultWindowSize); delta > 0 { if delta := uint32(icwz - defaultWindowSize); delta > 0 {
if err := framer.writeWindowUpdate(true, 0, delta); err != nil { if err := framer.fr.WriteWindowUpdate(0, delta); err != nil {
return nil, connectionErrorf(true, err, "transport: %v", err) return nil, connectionErrorf(true, err, "transport: %v", err)
} }
} }
@ -183,8 +183,10 @@ func newHTTP2Server(conn net.Conn, config *ServerConfig) (_ ServerTransport, err
kep.MinTime = defaultKeepalivePolicyMinTime kep.MinTime = defaultKeepalivePolicyMinTime
} }
var buf bytes.Buffer var buf bytes.Buffer
ctx, cancel := context.WithCancel(context.Background())
t := &http2Server{ t := &http2Server{
ctx: context.Background(), ctx: ctx,
cancel: cancel,
conn: conn, conn: conn,
remoteAddr: conn.RemoteAddr(), remoteAddr: conn.RemoteAddr(),
localAddr: conn.LocalAddr(), localAddr: conn.LocalAddr(),
@ -198,8 +200,6 @@ func newHTTP2Server(conn net.Conn, config *ServerConfig) (_ ServerTransport, err
fc: &inFlow{limit: uint32(icwz)}, fc: &inFlow{limit: uint32(icwz)},
sendQuotaPool: newQuotaPool(defaultWindowSize), sendQuotaPool: newQuotaPool(defaultWindowSize),
state: reachable, state: reachable,
writableChan: make(chan int, 1),
shutdownChan: make(chan struct{}),
activeStreams: make(map[uint32]*Stream), activeStreams: make(map[uint32]*Stream),
streamSendQuota: defaultWindowSize, streamSendQuota: defaultWindowSize,
stats: config.StatsHandler, stats: config.StatsHandler,
@ -222,9 +222,12 @@ func newHTTP2Server(conn net.Conn, config *ServerConfig) (_ ServerTransport, err
connBegin := &stats.ConnBegin{} connBegin := &stats.ConnBegin{}
t.stats.HandleConn(t.ctx, connBegin) t.stats.HandleConn(t.ctx, connBegin)
} }
go t.controller() t.framer.writer.Flush()
go func() {
loopyWriter(t.ctx, t.controlBuf, t.itemHandler)
t.Close()
}()
go t.keepalive() go t.keepalive()
t.writableChan <- 0
return t, nil return t, nil
} }
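
// The writableChan/shutdownChan handshake above is replaced by a cancellable context plus a
// single loopyWriter goroutine. A self-contained sketch of that shutdown pattern (illustrative
// types and timings, not the transport's actual ones):
package main

import (
	"context"
	"fmt"
	"time"
)

type server struct {
	ctx    context.Context
	cancel context.CancelFunc
}

func newServer() *server {
	ctx, cancel := context.WithCancel(context.Background())
	s := &server{ctx: ctx, cancel: cancel}
	go s.keepalive() // background goroutines watch s.ctx instead of a shutdownChan
	return s
}

func (s *server) keepalive() {
	ticker := time.NewTicker(50 * time.Millisecond)
	defer ticker.Stop()
	for {
		select {
		case <-ticker.C:
			fmt.Println("ping")
		case <-s.ctx.Done(): // Close() cancels the context and every loop exits here
			return
		}
	}
}

func (s *server) Close() { s.cancel() }

func main() {
	s := newServer()
	time.Sleep(120 * time.Millisecond)
	s.Close()
	time.Sleep(20 * time.Millisecond) // give the goroutine time to observe cancellation
}
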
@ -313,6 +316,7 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(
} }
t.maxStreamID = streamID t.maxStreamID = streamID
s.sendQuotaPool = newQuotaPool(int(t.streamSendQuota)) s.sendQuotaPool = newQuotaPool(int(t.streamSendQuota))
s.localSendQuota = newQuotaPool(defaultLocalSendQuota)
t.activeStreams[streamID] = s t.activeStreams[streamID] = s
if len(t.activeStreams) == 1 { if len(t.activeStreams) == 1 {
t.idle = time.Time{} t.idle = time.Time{}
@ -366,7 +370,7 @@ func (t *http2Server) HandleStreams(handle func(*Stream), traceCtx func(context.
return return
} }
frame, err := t.framer.readFrame() frame, err := t.framer.fr.ReadFrame()
if err == io.EOF || err == io.ErrUnexpectedEOF { if err == io.EOF || err == io.ErrUnexpectedEOF {
t.Close() t.Close()
return return
@ -386,7 +390,7 @@ func (t *http2Server) HandleStreams(handle func(*Stream), traceCtx func(context.
t.handleSettings(sf) t.handleSettings(sf)
for { for {
frame, err := t.framer.readFrame() frame, err := t.framer.fr.ReadFrame()
atomic.StoreUint32(&t.activity, 1) atomic.StoreUint32(&t.activity, 1)
if err != nil { if err != nil {
if se, ok := err.(http2.StreamError); ok { if se, ok := err.(http2.StreamError); ok {
@ -457,9 +461,9 @@ func (t *http2Server) adjustWindow(s *Stream, n uint32) {
} }
if w := s.fc.maybeAdjust(n); w > 0 { if w := s.fc.maybeAdjust(n); w > 0 {
if cw := t.fc.resetPendingUpdate(); cw > 0 { if cw := t.fc.resetPendingUpdate(); cw > 0 {
t.controlBuf.put(&windowUpdate{0, cw, false}) t.controlBuf.put(&windowUpdate{0, cw})
} }
t.controlBuf.put(&windowUpdate{s.id, w, true}) t.controlBuf.put(&windowUpdate{s.id, w})
} }
} }
@ -474,9 +478,9 @@ func (t *http2Server) updateWindow(s *Stream, n uint32) {
} }
if w := s.fc.onRead(n); w > 0 { if w := s.fc.onRead(n); w > 0 {
if cw := t.fc.resetPendingUpdate(); cw > 0 { if cw := t.fc.resetPendingUpdate(); cw > 0 {
t.controlBuf.put(&windowUpdate{0, cw, false}) t.controlBuf.put(&windowUpdate{0, cw})
} }
t.controlBuf.put(&windowUpdate{s.id, w, true}) t.controlBuf.put(&windowUpdate{s.id, w})
} }
} }
@ -490,7 +494,7 @@ func (t *http2Server) updateFlowControl(n uint32) {
} }
t.initialWindowSize = int32(n) t.initialWindowSize = int32(n)
t.mu.Unlock() t.mu.Unlock()
t.controlBuf.put(&windowUpdate{0, t.fc.newLimit(n), false}) t.controlBuf.put(&windowUpdate{0, t.fc.newLimit(n)})
t.controlBuf.put(&settings{ t.controlBuf.put(&settings{
ack: false, ack: false,
ss: []http2.Setting{ ss: []http2.Setting{
@ -521,7 +525,9 @@ func (t *http2Server) handleData(f *http2.DataFrame) {
// Furthermore, if a bdpPing is being sent out we can piggyback // Furthermore, if a bdpPing is being sent out we can piggyback
// connection's window update for the bytes we just received. // connection's window update for the bytes we just received.
if sendBDPPing { if sendBDPPing {
t.controlBuf.put(&windowUpdate{0, uint32(size), false}) if size != 0 { // Could be an empty frame.
t.controlBuf.put(&windowUpdate{0, uint32(size)})
}
t.controlBuf.put(bdpPing) t.controlBuf.put(bdpPing)
} else { } else {
if err := t.fc.onData(uint32(size)); err != nil { if err := t.fc.onData(uint32(size)); err != nil {
@ -530,7 +536,7 @@ func (t *http2Server) handleData(f *http2.DataFrame) {
return return
} }
if w := t.fc.onRead(uint32(size)); w > 0 { if w := t.fc.onRead(uint32(size)); w > 0 {
t.controlBuf.put(&windowUpdate{0, w, true}) t.controlBuf.put(&windowUpdate{0, w})
} }
} }
// Select the right stream to dispatch. // Select the right stream to dispatch.
@ -552,7 +558,7 @@ func (t *http2Server) handleData(f *http2.DataFrame) {
} }
if f.Header().Flags.Has(http2.FlagDataPadded) { if f.Header().Flags.Has(http2.FlagDataPadded) {
if w := s.fc.onRead(uint32(size) - uint32(len(f.Data()))); w > 0 { if w := s.fc.onRead(uint32(size) - uint32(len(f.Data()))); w > 0 {
t.controlBuf.put(&windowUpdate{s.id, w, true}) t.controlBuf.put(&windowUpdate{s.id, w})
} }
} }
s.mu.Unlock() s.mu.Unlock()
@ -593,10 +599,23 @@ func (t *http2Server) handleSettings(f *http2.SettingsFrame) {
ss = append(ss, s) ss = append(ss, s)
return nil return nil
}) })
// The settings will be applied once the ack is sent.
t.controlBuf.put(&settings{ack: true, ss: ss}) t.controlBuf.put(&settings{ack: true, ss: ss})
} }
func (t *http2Server) applySettings(ss []http2.Setting) {
for _, s := range ss {
if s.ID == http2.SettingInitialWindowSize {
t.mu.Lock()
for _, stream := range t.activeStreams {
stream.sendQuotaPool.addAndUpdate(int(s.Val) - int(t.streamSendQuota))
}
t.streamSendQuota = s.Val
t.mu.Unlock()
}
}
}
const ( const (
maxPingStrikes = 2 maxPingStrikes = 2
defaultPingTimeout = 2 * time.Hour defaultPingTimeout = 2 * time.Hour
@ -634,7 +653,7 @@ func (t *http2Server) handlePing(f *http2.PingFrame) {
t.mu.Unlock() t.mu.Unlock()
if ns < 1 && !t.kep.PermitWithoutStream { if ns < 1 && !t.kep.PermitWithoutStream {
// Keepalive shouldn't be active thus, this new ping should // Keepalive shouldn't be active thus, this new ping should
// have come after atleast defaultPingTimeout. // have come after at least defaultPingTimeout.
if t.lastPingAt.Add(defaultPingTimeout).After(now) { if t.lastPingAt.Add(defaultPingTimeout).After(now) {
t.pingStrikes++ t.pingStrikes++
} }
@ -647,6 +666,7 @@ func (t *http2Server) handlePing(f *http2.PingFrame) {
if t.pingStrikes > maxPingStrikes { if t.pingStrikes > maxPingStrikes {
// Send goaway and close the connection. // Send goaway and close the connection.
errorf("transport: Got to too many pings from the client, closing the connection.")
t.controlBuf.put(&goAway{code: http2.ErrCodeEnhanceYourCalm, debugData: []byte("too_many_pings"), closeConn: true}) t.controlBuf.put(&goAway{code: http2.ErrCodeEnhanceYourCalm, debugData: []byte("too_many_pings"), closeConn: true})
} }
} }
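
// The ping-strike counter above interacts with the server's keepalive enforcement policy:
// clients that ping more often than MinTime, or without active streams when that is not
// permitted, eventually draw a GOAWAY carrying "too_many_pings". A hedged sketch of matching
// server and client settings using the public keepalive API (durations are placeholders):
package main

import (
	"time"

	"google.golang.org/grpc"
	"google.golang.org/grpc/keepalive"
)

func main() {
	// Server: treat pings arriving more often than every 30s as abusive.
	_ = grpc.NewServer(grpc.KeepaliveEnforcementPolicy(keepalive.EnforcementPolicy{
		MinTime:             30 * time.Second,
		PermitWithoutStream: false,
	}))

	// Client: keep the ping interval at or above the server's MinTime so
	// pingStrikes never exceeds maxPingStrikes.
	conn, err := grpc.Dial("localhost:50051", // placeholder address
		grpc.WithInsecure(),
		grpc.WithKeepaliveParams(keepalive.ClientParameters{
			Time:    60 * time.Second,
			Timeout: 10 * time.Second,
		}),
	)
	if err != nil {
		return
	}
	defer conn.Close()
}
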
@ -663,47 +683,16 @@ func (t *http2Server) handleWindowUpdate(f *http2.WindowUpdateFrame) {
} }
} }
func (t *http2Server) writeHeaders(s *Stream, b *bytes.Buffer, endStream bool) error {
first := true
endHeaders := false
var err error
defer func() {
if err == nil {
// Reset ping strikes when sending headers since that might cause the
// peer to send ping.
atomic.StoreUint32(&t.resetPingStrikes, 1)
}
}()
// Sends the headers in a single batch.
for !endHeaders {
size := t.hBuf.Len()
if size > http2MaxFrameLen {
size = http2MaxFrameLen
} else {
endHeaders = true
}
if first {
p := http2.HeadersFrameParam{
StreamID: s.id,
BlockFragment: b.Next(size),
EndStream: endStream,
EndHeaders: endHeaders,
}
err = t.framer.writeHeaders(endHeaders, p)
first = false
} else {
err = t.framer.writeContinuation(endHeaders, s.id, endHeaders, b.Next(size))
}
if err != nil {
t.Close()
return connectionErrorf(true, err, "transport: %v", err)
}
}
return nil
}
// WriteHeader sends the header metadata md back to the client. // WriteHeader sends the header metadata md back to the client.
func (t *http2Server) WriteHeader(s *Stream, md metadata.MD) error { func (t *http2Server) WriteHeader(s *Stream, md metadata.MD) error {
select {
case <-s.ctx.Done():
return ContextErr(s.ctx.Err())
case <-t.ctx.Done():
return ErrConnClosing
default:
}
s.mu.Lock() s.mu.Lock()
if s.headerOk || s.state == streamDone { if s.headerOk || s.state == streamDone {
s.mu.Unlock() s.mu.Unlock()
@ -719,14 +708,13 @@ func (t *http2Server) WriteHeader(s *Stream, md metadata.MD) error {
} }
md = s.header md = s.header
s.mu.Unlock() s.mu.Unlock()
if _, err := wait(s.ctx, nil, nil, t.shutdownChan, t.writableChan); err != nil { // TODO(mmukhi): Benchmark if the performance gets better if we count the metadata and other header fields
return err // first and create a slice of that exact size.
} headerFields := make([]hpack.HeaderField, 0, 2) // at least :status, content-type will be there if none else.
t.hBuf.Reset() headerFields = append(headerFields, hpack.HeaderField{Name: ":status", Value: "200"})
t.hEnc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) headerFields = append(headerFields, hpack.HeaderField{Name: "content-type", Value: "application/grpc"})
t.hEnc.WriteField(hpack.HeaderField{Name: "content-type", Value: "application/grpc"})
if s.sendCompress != "" { if s.sendCompress != "" {
t.hEnc.WriteField(hpack.HeaderField{Name: "grpc-encoding", Value: s.sendCompress}) headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-encoding", Value: s.sendCompress})
} }
for k, vv := range md { for k, vv := range md {
if isReservedHeader(k) { if isReservedHeader(k) {
@ -734,20 +722,20 @@ func (t *http2Server) WriteHeader(s *Stream, md metadata.MD) error {
continue continue
} }
for _, v := range vv { for _, v := range vv {
t.hEnc.WriteField(hpack.HeaderField{Name: k, Value: encodeMetadataHeader(k, v)}) headerFields = append(headerFields, hpack.HeaderField{Name: k, Value: encodeMetadataHeader(k, v)})
} }
} }
bufLen := t.hBuf.Len() t.controlBuf.put(&headerFrame{
if err := t.writeHeaders(s, t.hBuf, false); err != nil { streamID: s.id,
return err hf: headerFields,
} endStream: false,
})
if t.stats != nil { if t.stats != nil {
outHeader := &stats.OutHeader{ outHeader := &stats.OutHeader{
WireLength: bufLen, //WireLength: // TODO(mmukhi): Revisit this later, if needed.
} }
t.stats.HandleRPC(s.Context(), outHeader) t.stats.HandleRPC(s.Context(), outHeader)
} }
t.writableChan <- 0
return nil return nil
} }
@ -756,6 +744,12 @@ func (t *http2Server) WriteHeader(s *Stream, md metadata.MD) error {
// TODO(zhaoq): Now it indicates the end of entire stream. Revisit if early // TODO(zhaoq): Now it indicates the end of entire stream. Revisit if early
// OK is adopted. // OK is adopted.
func (t *http2Server) WriteStatus(s *Stream, st *status.Status) error { func (t *http2Server) WriteStatus(s *Stream, st *status.Status) error {
select {
case <-t.ctx.Done():
return ErrConnClosing
default:
}
var headersSent, hasHeader bool var headersSent, hasHeader bool
s.mu.Lock() s.mu.Lock()
if s.state == streamDone { if s.state == streamDone {
@ -775,25 +769,15 @@ func (t *http2Server) WriteStatus(s *Stream, st *status.Status) error {
headersSent = true headersSent = true
} }
// Always write a status regardless of context cancellation unless the stream // TODO(mmukhi): Benchmark if the performance gets better if we count the metadata and other header fields
// is terminated (e.g. by a RST_STREAM, GOAWAY, or transport error). The // first and create a slice of that exact size.
// server's application code is already done so it is fine to ignore s.ctx. headerFields := make([]hpack.HeaderField, 0, 2) // grpc-status and grpc-message will be there if none else.
select {
case <-t.shutdownChan:
return ErrConnClosing
case <-t.writableChan:
}
t.hBuf.Reset()
if !headersSent { if !headersSent {
t.hEnc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) headerFields = append(headerFields, hpack.HeaderField{Name: ":status", Value: "200"})
t.hEnc.WriteField(hpack.HeaderField{Name: "content-type", Value: "application/grpc"}) headerFields = append(headerFields, hpack.HeaderField{Name: "content-type", Value: "application/grpc"})
} }
t.hEnc.WriteField( headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-status", Value: strconv.Itoa(int(st.Code()))})
hpack.HeaderField{ headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-message", Value: encodeGrpcMessage(st.Message())})
Name: "grpc-status",
Value: strconv.Itoa(int(st.Code())),
})
t.hEnc.WriteField(hpack.HeaderField{Name: "grpc-message", Value: encodeGrpcMessage(st.Message())})
if p := st.Proto(); p != nil && len(p.Details) > 0 { if p := st.Proto(); p != nil && len(p.Details) > 0 {
stBytes, err := proto.Marshal(p) stBytes, err := proto.Marshal(p)
@ -802,7 +786,7 @@ func (t *http2Server) WriteStatus(s *Stream, st *status.Status) error {
panic(err) panic(err)
} }
t.hEnc.WriteField(hpack.HeaderField{Name: "grpc-status-details-bin", Value: encodeBinHeader(stBytes)}) headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-status-details-bin", Value: encodeBinHeader(stBytes)})
} }
// Attach the trailer metadata. // Attach the trailer metadata.
@ -812,36 +796,32 @@ func (t *http2Server) WriteStatus(s *Stream, st *status.Status) error {
continue continue
} }
for _, v := range vv { for _, v := range vv {
t.hEnc.WriteField(hpack.HeaderField{Name: k, Value: encodeMetadataHeader(k, v)}) headerFields = append(headerFields, hpack.HeaderField{Name: k, Value: encodeMetadataHeader(k, v)})
} }
} }
bufLen := t.hBuf.Len() t.controlBuf.put(&headerFrame{
if err := t.writeHeaders(s, t.hBuf, true); err != nil { streamID: s.id,
t.Close() hf: headerFields,
return err endStream: true,
} })
if t.stats != nil { if t.stats != nil {
outTrailer := &stats.OutTrailer{ t.stats.HandleRPC(s.Context(), &stats.OutTrailer{})
WireLength: bufLen,
}
t.stats.HandleRPC(s.Context(), outTrailer)
} }
t.closeStream(s) t.closeStream(s)
t.writableChan <- 0
return nil return nil
} }
// Write converts the data into an HTTP2 data frame and sends it out. A non-nil error // Write converts the data into an HTTP2 data frame and sends it out. A non-nil error
// is returned if it fails (e.g., framing error, transport error). // is returned if it fails (e.g., framing error, transport error).
func (t *http2Server) Write(s *Stream, hdr []byte, data []byte, opts *Options) (err error) { func (t *http2Server) Write(s *Stream, hdr []byte, data []byte, opts *Options) (err error) {
// TODO(zhaoq): Support multi-writers for a single stream. select {
secondStart := http2MaxFrameLen - len(hdr)%http2MaxFrameLen case <-s.ctx.Done():
if len(data) < secondStart { return ContextErr(s.ctx.Err())
secondStart = len(data) case <-t.ctx.Done():
return ErrConnClosing
default:
} }
hdr = append(hdr, data[:secondStart]...)
data = data[secondStart:]
isLastSlice := (len(data) == 0)
var writeHeaderFrame bool var writeHeaderFrame bool
s.mu.Lock() s.mu.Lock()
if s.state == streamDone { if s.state == streamDone {
@ -855,24 +835,24 @@ func (t *http2Server) Write(s *Stream, hdr []byte, data []byte, opts *Options) (
if writeHeaderFrame { if writeHeaderFrame {
t.WriteHeader(s, nil) t.WriteHeader(s, nil)
} }
r := bytes.NewBuffer(hdr) // Add data to header frame so that we can equally distribute data across frames.
var ( emptyLen := http2MaxFrameLen - len(hdr)
p []byte if emptyLen > len(data) {
oqv uint32 emptyLen = len(data)
)
for {
if r.Len() == 0 && p == nil {
return nil
} }
oqv = atomic.LoadUint32(&t.outQuotaVersion) hdr = append(hdr, data[:emptyLen]...)
data = data[emptyLen:]
for _, r := range [][]byte{hdr, data} {
for len(r) > 0 {
size := http2MaxFrameLen size := http2MaxFrameLen
// Wait until the stream has some quota to send the data. // Wait until the stream has some quota to send the data.
sq, err := wait(s.ctx, nil, nil, t.shutdownChan, s.sendQuotaPool.acquire()) quotaChan, quotaVer := s.sendQuotaPool.acquireWithVersion()
sq, err := wait(s.ctx, t.ctx, nil, nil, quotaChan)
if err != nil { if err != nil {
return err return err
} }
// Wait until the transport has some quota to send the data. // Wait until the transport has some quota to send the data.
tq, err := wait(s.ctx, nil, nil, t.shutdownChan, t.sendQuotaPool.acquire()) tq, err := wait(s.ctx, t.ctx, nil, nil, t.sendQuotaPool.acquire())
if err != nil { if err != nil {
return err return err
} }
@ -882,97 +862,47 @@ func (t *http2Server) Write(s *Stream, hdr []byte, data []byte, opts *Options) (
if tq < size { if tq < size {
size = tq size = tq
} }
if p == nil { if size > len(r) {
p = r.Next(size) size = len(r)
} }
p := r[:size]
ps := len(p) ps := len(p)
if ps < sq {
// Overbooked stream quota. Return it back.
s.sendQuotaPool.add(sq - ps)
}
if ps < tq { if ps < tq {
// Overbooked transport quota. Return it back. // Overbooked transport quota. Return it back.
t.sendQuotaPool.add(tq - ps) t.sendQuotaPool.add(tq - ps)
} }
t.framer.adjustNumWriters(1) // Acquire local send quota to be able to write to the controlBuf.
// Got some quota. Try to acquire writing privilege on the ltq, err := wait(s.ctx, t.ctx, nil, nil, s.localSendQuota.acquire())
// transport. if err != nil {
if _, err := wait(s.ctx, nil, nil, t.shutdownChan, t.writableChan); err != nil { if _, ok := err.(ConnectionError); !ok {
if _, ok := err.(StreamError); ok {
// Return the connection quota back.
t.sendQuotaPool.add(ps) t.sendQuotaPool.add(ps)
} }
if t.framer.adjustNumWriters(-1) == 0 {
// This writer is the last one in this batch and has the
// responsibility to flush the buffered frames. It queues
// a flush request to controlBuf instead of flushing directly
// in order to avoid the race with other writing or flushing.
t.controlBuf.put(&flushIO{})
}
return err return err
} }
select { s.localSendQuota.add(ltq - ps) // It's ok we make this negative.
case <-s.ctx.Done():
t.sendQuotaPool.add(ps)
if t.framer.adjustNumWriters(-1) == 0 {
t.controlBuf.put(&flushIO{})
}
t.writableChan <- 0
return ContextErr(s.ctx.Err())
default:
}
if oqv != atomic.LoadUint32(&t.outQuotaVersion) {
// InitialWindowSize settings frame must have been received after we
// acquired send quota but before we got the writable channel.
// We must forsake this write.
t.sendQuotaPool.add(ps)
s.sendQuotaPool.add(ps)
if t.framer.adjustNumWriters(-1) == 0 {
t.controlBuf.put(&flushIO{})
}
t.writableChan <- 0
continue
}
var forceFlush bool
if r.Len() == 0 {
if isLastSlice {
if t.framer.adjustNumWriters(0) == 1 && !opts.Last {
forceFlush = true
}
} else {
r = bytes.NewBuffer(data)
isLastSlice = true
}
}
// Reset ping strikes when sending data since this might cause // Reset ping strikes when sending data since this might cause
// the peer to send ping. // the peer to send ping.
atomic.StoreUint32(&t.resetPingStrikes, 1) atomic.StoreUint32(&t.resetPingStrikes, 1)
if err := t.framer.writeData(forceFlush, s.id, false, p); err != nil { success := func() {
t.Close() t.controlBuf.put(&dataFrame{streamID: s.id, endStream: false, d: p, f: func() {
return connectionErrorf(true, err, "transport: %v", err) s.localSendQuota.add(ps)
}})
if ps < sq {
// Overbooked stream quota. Return it back.
s.sendQuotaPool.lockedAdd(sq - ps)
} }
p = nil r = r[ps:]
if t.framer.adjustNumWriters(-1) == 0 {
t.framer.flushWrite()
} }
t.writableChan <- 0 failure := func() {
s.sendQuotaPool.lockedAdd(sq)
} }
if !s.sendQuotaPool.compareAndExecute(quotaVer, success, failure) {
} t.sendQuotaPool.add(ps)
s.localSendQuota.add(ps)
func (t *http2Server) applySettings(ss []http2.Setting) {
for _, s := range ss {
if s.ID == http2.SettingInitialWindowSize {
t.mu.Lock()
defer t.mu.Unlock()
for _, stream := range t.activeStreams {
stream.sendQuotaPool.add(int(s.Val) - int(t.streamSendQuota))
} }
t.streamSendQuota = s.Val
atomic.AddUint32(&t.outQuotaVersion, 1)
} }
} }
return nil
} }
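
// Write now reserves stream quota under a version and commits it only after the dataFrame is
// queued; if a SETTINGS update bumped the version in between, the reservation is rolled back
// instead of forsaking the write. A simplified, illustrative quota pool showing the
// acquireWithVersion/compareAndExecute idea; this is a sketch of the mechanism, not grpc's
// actual control.go implementation.
package main

import (
	"fmt"
	"sync"
)

type quotaPool struct {
	mu      sync.Mutex
	c       chan int // holds the currently published quota, capacity 1
	quota   int      // quota not yet published on c
	version uint32   // bumped when quota is force-updated (e.g. by SETTINGS)
}

func newQuotaPool(q int) *quotaPool {
	qp := &quotaPool{c: make(chan int, 1)}
	qp.c <- q
	return qp
}

// lockedAdd merges v into the pool; callers must hold qp.mu.
func (qp *quotaPool) lockedAdd(v int) {
	select {
	case n := <-qp.c:
		qp.quota += n
	default:
	}
	qp.quota += v
	if qp.quota <= 0 {
		return
	}
	select {
	case qp.c <- qp.quota:
		qp.quota = 0
	default:
	}
}

func (qp *quotaPool) add(v int) { qp.mu.Lock(); qp.lockedAdd(v); qp.mu.Unlock() }

// addAndUpdate also bumps the version so in-flight acquisitions get rolled back and retried.
func (qp *quotaPool) addAndUpdate(v int) {
	qp.mu.Lock()
	qp.lockedAdd(v)
	qp.version++
	qp.mu.Unlock()
}

// acquireWithVersion hands back the quota channel plus the version it was read under;
// the caller receives from the channel (in the transport, via wait()).
func (qp *quotaPool) acquireWithVersion() (<-chan int, uint32) {
	qp.mu.Lock()
	defer qp.mu.Unlock()
	return qp.c, qp.version
}

// compareAndExecute commits success() only if no addAndUpdate happened since acquisition;
// otherwise failure() runs so the caller can hand the quota back.
func (qp *quotaPool) compareAndExecute(version uint32, success, failure func()) bool {
	qp.mu.Lock()
	defer qp.mu.Unlock()
	if version == qp.version {
		success()
		return true
	}
	failure()
	return false
}

func main() {
	qp := newQuotaPool(10)
	ch, ver := qp.acquireWithVersion()
	sq := <-ch // grab whatever is available (10)
	ps := 4    // bytes we actually want to send
	ok := qp.compareAndExecute(ver,
		func() { qp.lockedAdd(sq - ps) }, // keep what we used, return the rest
		func() { qp.lockedAdd(sq) },      // settings changed mid-flight: give it all back
	)
	fmt.Println("committed:", ok, "reserved:", sq, "sent:", ps)
}
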
// keepalive running in a separate goroutine does the following: // keepalive running in a separate goroutine does the following:
@ -988,7 +918,7 @@ func (t *http2Server) keepalive() {
maxAge := time.NewTimer(t.kp.MaxConnectionAge) maxAge := time.NewTimer(t.kp.MaxConnectionAge)
keepalive := time.NewTimer(t.kp.Time) keepalive := time.NewTimer(t.kp.Time)
// NOTE: All exit paths of this function should reset their // NOTE: All exit paths of this function should reset their
// respecitve timers. A failure to do so will cause the // respective timers. A failure to do so will cause the
// following clean-up to deadlock and eventually leak. // following clean-up to deadlock and eventually leak.
defer func() { defer func() {
if !maxIdle.Stop() { if !maxIdle.Stop() {
@ -1031,7 +961,7 @@ func (t *http2Server) keepalive() {
t.Close() t.Close()
// Resetting the timer so that the clean-up doesn't deadlock. // Resetting the timer so that the clean-up doesn't deadlock.
maxAge.Reset(infinity) maxAge.Reset(infinity)
case <-t.shutdownChan: case <-t.ctx.Done():
} }
return return
case <-keepalive.C: case <-keepalive.C:
@ -1049,7 +979,7 @@ func (t *http2Server) keepalive() {
pingSent = true pingSent = true
t.controlBuf.put(p) t.controlBuf.put(p)
keepalive.Reset(t.kp.Timeout) keepalive.Reset(t.kp.Timeout)
case <-t.shutdownChan: case <-t.ctx.Done():
return return
} }
} }
@ -1057,47 +987,85 @@ func (t *http2Server) keepalive() {
var goAwayPing = &ping{data: [8]byte{1, 6, 1, 8, 0, 3, 3, 9}} var goAwayPing = &ping{data: [8]byte{1, 6, 1, 8, 0, 3, 3, 9}}
// controller running in a separate goroutine takes charge of sending control // TODO(mmukhi): A lot of this code (and code in other places in the transport layer)
// frames (e.g., window update, reset stream, setting, etc.) to the server. // is duplicated between the client and the server.
func (t *http2Server) controller() { // The transport layer needs to be refactored to take care of this.
for { func (t *http2Server) itemHandler(i item) error {
select {
case i := <-t.controlBuf.get():
t.controlBuf.load()
select {
case <-t.writableChan:
switch i := i.(type) { switch i := i.(type) {
case *dataFrame:
if err := t.framer.fr.WriteData(i.streamID, i.endStream, i.d); err != nil {
return err
}
i.f()
return nil
case *headerFrame:
t.hBuf.Reset()
for _, f := range i.hf {
t.hEnc.WriteField(f)
}
first := true
endHeaders := false
for !endHeaders {
size := t.hBuf.Len()
if size > http2MaxFrameLen {
size = http2MaxFrameLen
} else {
endHeaders = true
}
var err error
if first {
first = false
err = t.framer.fr.WriteHeaders(http2.HeadersFrameParam{
StreamID: i.streamID,
BlockFragment: t.hBuf.Next(size),
EndStream: i.endStream,
EndHeaders: endHeaders,
})
} else {
err = t.framer.fr.WriteContinuation(
i.streamID,
endHeaders,
t.hBuf.Next(size),
)
}
if err != nil {
return err
}
}
atomic.StoreUint32(&t.resetPingStrikes, 1)
return nil
case *windowUpdate: case *windowUpdate:
t.framer.writeWindowUpdate(i.flush, i.streamID, i.increment) return t.framer.fr.WriteWindowUpdate(i.streamID, i.increment)
case *settings: case *settings:
if i.ack { if i.ack {
t.framer.writeSettingsAck(true)
t.applySettings(i.ss) t.applySettings(i.ss)
} else { return t.framer.fr.WriteSettingsAck()
t.framer.writeSettings(true, i.ss...)
} }
return t.framer.fr.WriteSettings(i.ss...)
case *resetStream: case *resetStream:
t.framer.writeRSTStream(true, i.streamID, i.code) return t.framer.fr.WriteRSTStream(i.streamID, i.code)
case *goAway: case *goAway:
t.mu.Lock() t.mu.Lock()
if t.state == closing { if t.state == closing {
t.mu.Unlock() t.mu.Unlock()
// The transport is closing. // The transport is closing.
return return fmt.Errorf("transport: Connection closing")
} }
sid := t.maxStreamID sid := t.maxStreamID
if !i.headsUp { if !i.headsUp {
// Stop accepting more streams now. // Stop accepting more streams now.
t.state = draining t.state = draining
activeStreams := len(t.activeStreams)
t.mu.Unlock() t.mu.Unlock()
t.framer.writeGoAway(true, sid, i.code, i.debugData) if err := t.framer.fr.WriteGoAway(sid, i.code, i.debugData); err != nil {
if i.closeConn || activeStreams == 0 { return err
// Abruptly close the connection following the GoAway.
t.Close()
} }
t.writableChan <- 0 if i.closeConn {
continue // Abruptly close the connection following the GoAway (via
// loopywriter). But flush out what's inside the buffer first.
t.framer.writer.Flush()
return fmt.Errorf("transport: Connection closing")
}
return nil
} }
t.mu.Unlock() t.mu.Unlock()
// For a graceful close, send out a GoAway with stream ID of MaxUInt32, // For a graceful close, send out a GoAway with stream ID of MaxUInt32,
@ -1106,44 +1074,42 @@ func (t *http2Server) controller() {
// originated before the GoAway reaches the client. // originated before the GoAway reaches the client.
// After getting the ack or timer expiration send out another GoAway this // After getting the ack or timer expiration send out another GoAway this
// time with an ID of the max stream server intends to process. // time with an ID of the max stream server intends to process.
t.framer.writeGoAway(true, math.MaxUint32, http2.ErrCodeNo, []byte{}) if err := t.framer.fr.WriteGoAway(math.MaxUint32, http2.ErrCodeNo, []byte{}); err != nil {
t.framer.writePing(true, false, goAwayPing.data) return err
}
if err := t.framer.fr.WritePing(false, goAwayPing.data); err != nil {
return err
}
go func() { go func() {
timer := time.NewTimer(time.Minute) timer := time.NewTimer(time.Minute)
defer timer.Stop() defer timer.Stop()
select { select {
case <-t.drainChan: case <-t.drainChan:
case <-timer.C: case <-timer.C:
case <-t.shutdownChan: case <-t.ctx.Done():
return return
} }
t.controlBuf.put(&goAway{code: i.code, debugData: i.debugData}) t.controlBuf.put(&goAway{code: i.code, debugData: i.debugData})
}() }()
return nil
case *flushIO: case *flushIO:
t.framer.flushWrite() return t.framer.writer.Flush()
case *ping: case *ping:
if !i.ack { if !i.ack {
t.bdpEst.timesnap(i.data) t.bdpEst.timesnap(i.data)
} }
t.framer.writePing(true, i.ack, i.data) return t.framer.fr.WritePing(i.ack, i.data)
default: default:
errorf("transport: http2Server.controller got unexpected item type %v\n", i) err := status.Errorf(codes.Internal, "transport: http2Server.controller got unexpected item type %t", i)
} errorf("%v", err)
t.writableChan <- 0 return err
continue
case <-t.shutdownChan:
return
}
case <-t.shutdownChan:
return
}
} }
} }
// Close starts shutting down the http2Server transport. // Close starts shutting down the http2Server transport.
// TODO(zhaoq): Now the destruction is not blocked on any pending streams. This // TODO(zhaoq): Now the destruction is not blocked on any pending streams. This
// could cause some resource issue. Revisit this later. // could cause some resource issue. Revisit this later.
func (t *http2Server) Close() (err error) { func (t *http2Server) Close() error {
t.mu.Lock() t.mu.Lock()
if t.state == closing { if t.state == closing {
t.mu.Unlock() t.mu.Unlock()
@ -1153,8 +1119,8 @@ func (t *http2Server) Close() (err error) {
streams := t.activeStreams streams := t.activeStreams
t.activeStreams = nil t.activeStreams = nil
t.mu.Unlock() t.mu.Unlock()
close(t.shutdownChan) t.cancel()
err = t.conn.Close() err := t.conn.Close()
// Cancel all active streams. // Cancel all active streams.
for _, s := range streams { for _, s := range streams {
s.cancel() s.cancel()
@ -1163,7 +1129,7 @@ func (t *http2Server) Close() (err error) {
connEnd := &stats.ConnEnd{} connEnd := &stats.ConnEnd{}
t.stats.HandleConn(t.ctx, connEnd) t.stats.HandleConn(t.ctx, connEnd)
} }
return return err
} }
// closeStream clears the footprint of a stream when the stream is not needed // closeStream clears the footprint of a stream when the stream is not needed


@ -28,7 +28,6 @@ import (
"net/http" "net/http"
"strconv" "strconv"
"strings" "strings"
"sync/atomic"
"time" "time"
"github.com/golang/protobuf/proto" "github.com/golang/protobuf/proto"
@ -45,7 +44,8 @@ const (
// http://http2.github.io/http2-spec/#SettingValues // http://http2.github.io/http2-spec/#SettingValues
http2InitHeaderTableSize = 4096 http2InitHeaderTableSize = 4096
// http2IOBufSize specifies the buffer size for sending frames. // http2IOBufSize specifies the buffer size for sending frames.
http2IOBufSize = 32 * 1024 defaultWriteBufSize = 32 * 1024
defaultReadBufSize = 32 * 1024
) )
var ( var (
@ -475,10 +475,10 @@ type framer struct {
fr *http2.Framer fr *http2.Framer
} }
func newFramer(conn net.Conn) *framer { func newFramer(conn net.Conn, writeBufferSize, readBufferSize int) *framer {
f := &framer{ f := &framer{
reader: bufio.NewReaderSize(conn, http2IOBufSize), reader: bufio.NewReaderSize(conn, readBufferSize),
writer: bufio.NewWriterSize(conn, http2IOBufSize), writer: bufio.NewWriterSize(conn, writeBufferSize),
} }
f.fr = http2.NewFramer(f.writer, f.reader) f.fr = http2.NewFramer(f.writer, f.reader)
// Opt-in to Frame reuse API on framer to reduce garbage. // Opt-in to Frame reuse API on framer to reduce garbage.
@ -487,132 +487,3 @@ func newFramer(conn net.Conn) *framer {
f.fr.ReadMetaHeaders = hpack.NewDecoder(http2InitHeaderTableSize, nil) f.fr.ReadMetaHeaders = hpack.NewDecoder(http2InitHeaderTableSize, nil)
return f return f
} }
func (f *framer) adjustNumWriters(i int32) int32 {
return atomic.AddInt32(&f.numWriters, i)
}
// The following writeXXX functions can only be called when the caller gets
// unblocked from writableChan channel (i.e., owns the privilege to write).
func (f *framer) writeContinuation(forceFlush bool, streamID uint32, endHeaders bool, headerBlockFragment []byte) error {
if err := f.fr.WriteContinuation(streamID, endHeaders, headerBlockFragment); err != nil {
return err
}
if forceFlush {
return f.writer.Flush()
}
return nil
}
func (f *framer) writeData(forceFlush bool, streamID uint32, endStream bool, data []byte) error {
if err := f.fr.WriteData(streamID, endStream, data); err != nil {
return err
}
if forceFlush {
return f.writer.Flush()
}
return nil
}
func (f *framer) writeGoAway(forceFlush bool, maxStreamID uint32, code http2.ErrCode, debugData []byte) error {
if err := f.fr.WriteGoAway(maxStreamID, code, debugData); err != nil {
return err
}
if forceFlush {
return f.writer.Flush()
}
return nil
}
func (f *framer) writeHeaders(forceFlush bool, p http2.HeadersFrameParam) error {
if err := f.fr.WriteHeaders(p); err != nil {
return err
}
if forceFlush {
return f.writer.Flush()
}
return nil
}
func (f *framer) writePing(forceFlush, ack bool, data [8]byte) error {
if err := f.fr.WritePing(ack, data); err != nil {
return err
}
if forceFlush {
return f.writer.Flush()
}
return nil
}
func (f *framer) writePriority(forceFlush bool, streamID uint32, p http2.PriorityParam) error {
if err := f.fr.WritePriority(streamID, p); err != nil {
return err
}
if forceFlush {
return f.writer.Flush()
}
return nil
}
func (f *framer) writePushPromise(forceFlush bool, p http2.PushPromiseParam) error {
if err := f.fr.WritePushPromise(p); err != nil {
return err
}
if forceFlush {
return f.writer.Flush()
}
return nil
}
func (f *framer) writeRSTStream(forceFlush bool, streamID uint32, code http2.ErrCode) error {
if err := f.fr.WriteRSTStream(streamID, code); err != nil {
return err
}
if forceFlush {
return f.writer.Flush()
}
return nil
}
func (f *framer) writeSettings(forceFlush bool, settings ...http2.Setting) error {
if err := f.fr.WriteSettings(settings...); err != nil {
return err
}
if forceFlush {
return f.writer.Flush()
}
return nil
}
func (f *framer) writeSettingsAck(forceFlush bool) error {
if err := f.fr.WriteSettingsAck(); err != nil {
return err
}
if forceFlush {
return f.writer.Flush()
}
return nil
}
func (f *framer) writeWindowUpdate(forceFlush bool, streamID, incr uint32) error {
if err := f.fr.WriteWindowUpdate(streamID, incr); err != nil {
return err
}
if forceFlush {
return f.writer.Flush()
}
return nil
}
func (f *framer) flushWrite() error {
return f.writer.Flush()
}
func (f *framer) readFrame() (http2.Frame, error) {
return f.fr.ReadFrame()
}
func (f *framer) errorDetail() error {
return f.fr.ErrorDetail()
}


@ -21,10 +21,12 @@
package transport // import "google.golang.org/grpc/transport" package transport // import "google.golang.org/grpc/transport"
import ( import (
stdctx "context"
"fmt" "fmt"
"io" "io"
"net" "net"
"sync" "sync"
"time"
"golang.org/x/net/context" "golang.org/x/net/context"
"golang.org/x/net/http2" "golang.org/x/net/http2"
@ -67,20 +69,20 @@ func newRecvBuffer() *recvBuffer {
func (b *recvBuffer) put(r recvMsg) { func (b *recvBuffer) put(r recvMsg) {
b.mu.Lock() b.mu.Lock()
defer b.mu.Unlock()
if len(b.backlog) == 0 { if len(b.backlog) == 0 {
select { select {
case b.c <- r: case b.c <- r:
b.mu.Unlock()
return return
default: default:
} }
} }
b.backlog = append(b.backlog, r) b.backlog = append(b.backlog, r)
b.mu.Unlock()
} }
func (b *recvBuffer) load() { func (b *recvBuffer) load() {
b.mu.Lock() b.mu.Lock()
defer b.mu.Unlock()
if len(b.backlog) > 0 { if len(b.backlog) > 0 {
select { select {
case b.c <- b.backlog[0]: case b.c <- b.backlog[0]:
@ -89,6 +91,7 @@ func (b *recvBuffer) load() {
default: default:
} }
} }
b.mu.Unlock()
} }
// get returns the channel that receives a recvMsg in the buffer. // get returns the channel that receives a recvMsg in the buffer.
@ -164,20 +167,20 @@ func newControlBuffer() *controlBuffer {
func (b *controlBuffer) put(r item) { func (b *controlBuffer) put(r item) {
b.mu.Lock() b.mu.Lock()
defer b.mu.Unlock()
if len(b.backlog) == 0 { if len(b.backlog) == 0 {
select { select {
case b.c <- r: case b.c <- r:
b.mu.Unlock()
return return
default: default:
} }
} }
b.backlog = append(b.backlog, r) b.backlog = append(b.backlog, r)
b.mu.Unlock()
} }
func (b *controlBuffer) load() { func (b *controlBuffer) load() {
b.mu.Lock() b.mu.Lock()
defer b.mu.Unlock()
if len(b.backlog) > 0 { if len(b.backlog) > 0 {
select { select {
case b.c <- b.backlog[0]: case b.c <- b.backlog[0]:
@ -186,6 +189,7 @@ func (b *controlBuffer) load() {
default: default:
} }
} }
b.mu.Unlock()
} }
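
// recvBuffer and controlBuffer drop the deferred unlocks on this hot path; the pattern is
// unchanged: put hands an item straight to a 1-slot channel when possible and otherwise
// appends to a backlog, and load refills the channel after every receive. A self-contained
// sketch of that unbounded buffer, using string items for simplicity:
package main

import (
	"fmt"
	"sync"
)

type itemBuffer struct {
	c       chan string
	mu      sync.Mutex
	backlog []string
}

func newItemBuffer() *itemBuffer { return &itemBuffer{c: make(chan string, 1)} }

func (b *itemBuffer) put(s string) {
	b.mu.Lock()
	if len(b.backlog) == 0 {
		select {
		case b.c <- s: // fast path: hand the item straight to the consumer
			b.mu.Unlock()
			return
		default:
		}
	}
	b.backlog = append(b.backlog, s)
	b.mu.Unlock()
}

// load moves the next backlogged item into the channel; consumers call it after
// every receive, exactly as loopyWriter does with cbuf.load().
func (b *itemBuffer) load() {
	b.mu.Lock()
	if len(b.backlog) > 0 {
		select {
		case b.c <- b.backlog[0]:
			b.backlog = b.backlog[1:]
		default:
		}
	}
	b.mu.Unlock()
}

func (b *itemBuffer) get() <-chan string { return b.c }

func main() {
	buf := newItemBuffer()
	buf.put("window update")
	buf.put("settings ack")
	for i := 0; i < 2; i++ {
		it := <-buf.get()
		buf.load()
		fmt.Println(it)
	}
}
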
// get returns the channel that receives an item in the buffer. // get returns the channel that receives an item in the buffer.
@ -236,6 +240,7 @@ type Stream struct {
requestRead func(int) requestRead func(int)
sendQuotaPool *quotaPool sendQuotaPool *quotaPool
localSendQuota *quotaPool
// Close headerChan to indicate the end of reception of header metadata. // Close headerChan to indicate the end of reception of header metadata.
headerChan chan struct{} headerChan chan struct{}
// header caches the received header metadata. // header caches the received header metadata.
@ -313,8 +318,9 @@ func (s *Stream) Header() (metadata.MD, error) {
// side only. // side only.
func (s *Stream) Trailer() metadata.MD { func (s *Stream) Trailer() metadata.MD {
s.mu.RLock() s.mu.RLock()
defer s.mu.RUnlock() c := s.trailer.Copy()
return s.trailer.Copy() s.mu.RUnlock()
return c
} }
// ServerTransport returns the underlying ServerTransport for the stream. // ServerTransport returns the underlying ServerTransport for the stream.
@ -342,14 +348,16 @@ func (s *Stream) Status() *status.Status {
// Server side only. // Server side only.
func (s *Stream) SetHeader(md metadata.MD) error { func (s *Stream) SetHeader(md metadata.MD) error {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock()
if s.headerOk || s.state == streamDone { if s.headerOk || s.state == streamDone {
s.mu.Unlock()
return ErrIllegalHeaderWrite return ErrIllegalHeaderWrite
} }
if md.Len() == 0 { if md.Len() == 0 {
s.mu.Unlock()
return nil return nil
} }
s.header = metadata.Join(s.header, md) s.header = metadata.Join(s.header, md)
s.mu.Unlock()
return nil return nil
} }
@ -360,8 +368,8 @@ func (s *Stream) SetTrailer(md metadata.MD) error {
return nil return nil
} }
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock()
s.trailer = metadata.Join(s.trailer, md) s.trailer = metadata.Join(s.trailer, md)
s.mu.Unlock()
return nil return nil
} }
@ -412,15 +420,17 @@ func (s *Stream) finish(st *status.Status) {
// BytesSent indicates whether any bytes have been sent on this stream. // BytesSent indicates whether any bytes have been sent on this stream.
func (s *Stream) BytesSent() bool { func (s *Stream) BytesSent() bool {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() bs := s.bytesSent
return s.bytesSent s.mu.Unlock()
return bs
} }
// BytesReceived indicates whether any bytes have been received on this stream. // BytesReceived indicates whether any bytes have been received on this stream.
func (s *Stream) BytesReceived() bool { func (s *Stream) BytesReceived() bool {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() br := s.bytesReceived
return s.bytesReceived s.mu.Unlock()
return br
} }
// GoString is implemented by Stream so context.String() won't // GoString is implemented by Stream so context.String() won't
@ -449,7 +459,6 @@ type transportState int
const ( const (
reachable transportState = iota reachable transportState = iota
unreachable
closing closing
draining draining
) )
@ -464,6 +473,8 @@ type ServerConfig struct {
KeepalivePolicy keepalive.EnforcementPolicy KeepalivePolicy keepalive.EnforcementPolicy
InitialWindowSize int32 InitialWindowSize int32
InitialConnWindowSize int32 InitialConnWindowSize int32
WriteBufferSize int
ReadBufferSize int
} }
// NewServerTransport creates a ServerTransport with conn or non-nil error // NewServerTransport creates a ServerTransport with conn or non-nil error
@ -491,10 +502,14 @@ type ConnectOptions struct {
KeepaliveParams keepalive.ClientParameters KeepaliveParams keepalive.ClientParameters
// StatsHandler stores the handler for stats. // StatsHandler stores the handler for stats.
StatsHandler stats.Handler StatsHandler stats.Handler
// InitialWindowSize sets the intial window size for a stream. // InitialWindowSize sets the initial window size for a stream.
InitialWindowSize int32 InitialWindowSize int32
// InitialConnWindowSize sets the intial window size for a connection. // InitialConnWindowSize sets the initial window size for a connection.
InitialConnWindowSize int32 InitialConnWindowSize int32
// WriteBufferSize sets the size of write buffer which in turn determines how much data can be batched before it's written on the wire.
WriteBufferSize int
// ReadBufferSize sets the size of read buffer, which in turn determines how much data can be read at most for one read syscall.
ReadBufferSize int
} }
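
// The same two knobs reach clients through ConnectOptions. A hedged sketch of setting them via
// the public dial options added in gRPC v1.7 (names assumed to be WithWriteBufferSize and
// WithReadBufferSize; the endpoint and sizes are placeholders):
package main

import (
	"log"

	"google.golang.org/grpc"
)

func main() {
	// Zero values keep the 32KB defaults; larger buffers trade memory per
	// connection for fewer syscalls on busy streams.
	conn, err := grpc.Dial("localhost:2379", // placeholder endpoint
		grpc.WithInsecure(),
		grpc.WithWriteBufferSize(128*1024),
		grpc.WithReadBufferSize(128*1024),
	)
	if err != nil {
		log.Fatal(err)
	}
	defer conn.Close()
}
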
// TargetInfo contains the information of the target such as network address and metadata. // TargetInfo contains the information of the target such as network address and metadata.
@ -505,8 +520,8 @@ type TargetInfo struct {
// NewClientTransport establishes the transport with the required ConnectOptions // NewClientTransport establishes the transport with the required ConnectOptions
// and returns it to the caller. // and returns it to the caller.
func NewClientTransport(ctx context.Context, target TargetInfo, opts ConnectOptions) (ClientTransport, error) { func NewClientTransport(ctx context.Context, target TargetInfo, opts ConnectOptions, timeout time.Duration) (ClientTransport, error) {
return newHTTP2Client(ctx, target, opts) return newHTTP2Client(ctx, target, opts, timeout)
} }
// Options provides additional hints and information for message // Options provides additional hints and information for message
@ -518,7 +533,7 @@ type Options struct {
// Delay is a hint to the transport implementation for whether // Delay is a hint to the transport implementation for whether
// the data could be buffered for a batching write. The // the data could be buffered for a batching write. The
// Transport implementation may ignore the hint. // transport implementation may ignore the hint.
Delay bool Delay bool
} }
@ -688,34 +703,33 @@ func (e StreamError) Error() string {
return fmt.Sprintf("stream error: code = %s desc = %q", e.Code, e.Desc) return fmt.Sprintf("stream error: code = %s desc = %q", e.Code, e.Desc)
} }
// wait blocks until it can receive from ctx.Done, closing, or proceed. // wait blocks until it can receive from one of the provided contexts or channels
// If it receives from ctx.Done, it returns 0, the StreamError for ctx.Err. func wait(ctx, tctx context.Context, done, goAway <-chan struct{}, proceed <-chan int) (int, error) {
// If it receives from done, it returns 0, io.EOF if ctx is not done; otherwise
// it return the StreamError for ctx.Err.
// If it receives from goAway, it returns 0, ErrStreamDrain.
// If it receives from closing, it returns 0, ErrConnClosing.
// If it receives from proceed, it returns the received integer, nil.
func wait(ctx context.Context, done, goAway, closing <-chan struct{}, proceed <-chan int) (int, error) {
select { select {
case <-ctx.Done(): case <-ctx.Done():
return 0, ContextErr(ctx.Err()) return 0, ContextErr(ctx.Err())
case <-done: case <-done:
// User cancellation has precedence.
select {
case <-ctx.Done():
return 0, ContextErr(ctx.Err())
default:
}
return 0, io.EOF return 0, io.EOF
case <-goAway: case <-goAway:
return 0, ErrStreamDrain return 0, ErrStreamDrain
case <-closing: case <-tctx.Done():
return 0, ErrConnClosing return 0, ErrConnClosing
case i := <-proceed: case i := <-proceed:
return i, nil return i, nil
} }
} }
// ContextErr converts the error from context package into a StreamError.
func ContextErr(err error) StreamError {
switch err {
case context.DeadlineExceeded, stdctx.DeadlineExceeded:
return streamErrorf(codes.DeadlineExceeded, "%v", err)
case context.Canceled, stdctx.Canceled:
return streamErrorf(codes.Canceled, "%v", err)
}
return streamErrorf(codes.Internal, "Unexpected error from context packet: %v", err)
}
// GoAwayReason contains the reason for the GoAway frame received. // GoAwayReason contains the reason for the GoAway frame received.
type GoAwayReason uint8 type GoAwayReason uint8
@ -725,6 +739,39 @@ const (
// NoReason is the default value when GoAway frame is received. // NoReason is the default value when GoAway frame is received.
NoReason GoAwayReason = 1 NoReason GoAwayReason = 1
// TooManyPings indicates that a GoAway frame with ErrCodeEnhanceYourCalm // TooManyPings indicates that a GoAway frame with ErrCodeEnhanceYourCalm
// was recieved and that the debug data said "too_many_pings". // was received and that the debug data said "too_many_pings".
TooManyPings GoAwayReason = 2 TooManyPings GoAwayReason = 2
) )
// loopyWriter is run in a separate goroutine. It is the single code path that will
// write data on the wire.
func loopyWriter(ctx context.Context, cbuf *controlBuffer, handler func(item) error) {
for {
select {
case i := <-cbuf.get():
cbuf.load()
if err := handler(i); err != nil {
return
}
case <-ctx.Done():
return
}
hasData:
for {
select {
case i := <-cbuf.get():
cbuf.load()
if err := handler(i); err != nil {
return
}
case <-ctx.Done():
return
default:
if err := handler(&flushIO{}); err != nil {
return
}
break hasData
}
}
}
}
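
// loopyWriter batches writes: it keeps handling items while more are queued and only issues a
// flushIO once the control buffer momentarily runs dry, so a burst of frames costs one Flush.
// A standalone sketch of that drain-then-flush loop over a plain channel and a bufio.Writer,
// not the transport's internal types:
package main

import (
	"bufio"
	"context"
	"os"
	"time"
)

// drainThenFlush mimics loopyWriter: handle items while they keep coming,
// flush the writer only when the queue momentarily empties.
func drainThenFlush(ctx context.Context, items <-chan string, w *bufio.Writer) {
	for {
		select {
		case it := <-items:
			w.WriteString(it + "\n") // "handler": buffer the frame
		case <-ctx.Done():
			w.Flush()
			return
		}
	hasData:
		for {
			select {
			case it := <-items:
				w.WriteString(it + "\n")
			case <-ctx.Done():
				w.Flush()
				return
			default:
				w.Flush() // queue is empty: push the batch onto the wire
				break hasData
			}
		}
	}
}

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	items := make(chan string, 8)
	for _, f := range []string{"HEADERS", "DATA", "DATA", "WINDOW_UPDATE"} {
		items <- f
	}
	w := bufio.NewWriter(os.Stdout)
	go drainThenFlush(ctx, items, w)
	time.Sleep(50 * time.Millisecond)
	cancel()
	time.Sleep(10 * time.Millisecond) // let the writer observe cancellation and flush
}
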

glide.lock generated

@ -1,5 +1,5 @@
hash: cff74aae5a6b8c11816c9994dedfdfdcd9f4137d61d8ed8ba0bf623f0ff21d50 hash: 8d556efcf8d917aba74445bd98eb9de3abcf981ca70adb53276026c4c3ecfcfd
updated: 2017-11-10T09:46:28.3753-08:00 updated: 2017-11-10T12:07:01.187305-08:00
imports: imports:
- name: github.com/beorn7/perks - name: github.com/beorn7/perks
version: 4c0e84591b9aa9e6dcfdf3e020114cd81f89d5f9 version: 4c0e84591b9aa9e6dcfdf3e020114cd81f89d5f9
@ -158,8 +158,9 @@ imports:
subpackages: subpackages:
- googleapis/rpc/status - googleapis/rpc/status
- name: google.golang.org/grpc - name: google.golang.org/grpc
version: f92cdcd7dcdc69e81b2d7b338479a19a8723cfa3 version: 5ffe3083946d5603a0578721101dc8165b1d5b5f
subpackages: subpackages:
- balancer
- codes - codes
- connectivity - connectivity
- credentials - credentials
@ -172,6 +173,7 @@ imports:
- metadata - metadata
- naming - naming
- peer - peer
- resolver
- stats - stats
- status - status
- tap - tap


@ -108,7 +108,7 @@ import:
subpackages: subpackages:
- rate - rate
- package: google.golang.org/grpc - package: google.golang.org/grpc
version: v1.6.0 version: v1.7.2
subpackages: subpackages:
- codes - codes
- credentials - credentials