// Copyright 2016 The etcd Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package grpcproxy

import (
	"io"
	"sync"
	"sync/atomic"
	"time"

	"google.golang.org/grpc"
	"google.golang.org/grpc/metadata"

	"golang.org/x/net/context"

	"github.com/coreos/etcd/clientv3"
	"github.com/coreos/etcd/etcdserver/api/v3rpc/rpctypes"
	pb "github.com/coreos/etcd/etcdserver/etcdserverpb"
)

// leaseProxy implements the lease RPCs of the gRPC proxy by forwarding them
// to an etcd backend through a clientv3 client.
type leaseProxy struct {
	// leaseClient handles req from LeaseGrant() that requires a lease ID.
	leaseClient pb.LeaseClient

	// lessor is the clientv3 lease API used for Revoke, TimeToLive and
	// KeepAlive requests.
	lessor clientv3.Lease

	// ctx is checked by LeaseKeepAlive before a new stream is admitted and
	// is watched by running streams; once it is done, streams are torn down.
	ctx context.Context

	// leader supplies the leader-related notification channels (lost,
	// disconnect, stop) and is signalled via gotLeader() after successful
	// LeaseGrant and LeaseRevoke calls.
	leader *leader

	// mu protects adding outstanding leaseProxyStream through wg.
	mu sync.RWMutex

	// wg waits until all outstanding leaseProxyStream quit.
	wg sync.WaitGroup
}

// NewLeaseProxy creates a lease proxy on top of the given client.
//
// It returns the pb.LeaseServer implementation together with a channel that
// is closed once the proxy has fully shut down, i.e. after the leader tracker
// has stopped, the proxy context is done, and every outstanding
// LeaseKeepAlive stream has finished.
func NewLeaseProxy(c *clientv3.Client) (pb.LeaseServer, <-chan struct{}) {
	// the proxy context is derived from the client context so that it can be
	// cancelled independently when the leader tracker reports a disconnect.
	cctx, cancel := context.WithCancel(c.Ctx())
	lp := &leaseProxy{
		leaseClient: pb.NewLeaseClient(c.ActiveConnection()),
		lessor:      c.Lease,
		ctx:         cctx,
		leader:      newLeader(c.Ctx(), c.Watcher),
	}
	ch := make(chan struct{})
	go func() {
		defer close(ch)
		<-lp.leader.stopNotify()
		// hold mu while waiting for the context to finish: LeaseKeepAlive
		// takes the same lock before calling wg.Add, so once the lock is
		// released every new stream sees a done context and no wg.Add can
		// race with the wg.Wait below.
		lp.mu.Lock()
		select {
		case <-lp.ctx.Done():
		case <-lp.leader.disconnectNotify():
			cancel()
		}
		<-lp.ctx.Done()
		lp.mu.Unlock()
		// wait for all admitted keepalive streams to finish.
		lp.wg.Wait()
	}()
	return lp, ch
}

// LeaseGrant forwards the grant request to the backend through the raw lease
// client. On success it signals the leader tracker via gotLeader() and returns
// the backend response; on failure the backend error is returned unchanged.
func (lp *leaseProxy) LeaseGrant(ctx context.Context, cr *pb.LeaseGrantRequest) (*pb.LeaseGrantResponse, error) {
	granted, grantErr := lp.leaseClient.LeaseGrant(ctx, cr)
	if grantErr != nil {
		return nil, grantErr
	}

	// a successful round trip is reported to the leader tracker
	lp.leader.gotLeader()
	return granted, nil
}

// LeaseRevoke revokes the lease named in the request through the clientv3
// lessor. On success it signals the leader tracker via gotLeader() and returns
// the client response converted to its protobuf type; on failure the error is
// returned unchanged.
func (lp *leaseProxy) LeaseRevoke(ctx context.Context, rr *pb.LeaseRevokeRequest) (*pb.LeaseRevokeResponse, error) {
	leaseID := clientv3.LeaseID(rr.ID)
	revoked, revokeErr := lp.lessor.Revoke(ctx, leaseID)
	if revokeErr != nil {
		return nil, revokeErr
	}

	// a successful round trip is reported to the leader tracker
	lp.leader.gotLeader()
	resp := (*pb.LeaseRevokeResponse)(revoked)
	return resp, nil
}

// LeaseTimeToLive queries the backend for the remaining and granted TTL of the
// lease named in the request. When rr.Keys is set, the keys attached to the
// lease are requested as well. The client response is copied field by field
// into the protobuf response type.
func (lp *leaseProxy) LeaseTimeToLive(ctx context.Context, rr *pb.LeaseTimeToLiveRequest) (*pb.LeaseTimeToLiveResponse, error) {
	// collect the optional lease options up front so there is a single call site
	var opts []clientv3.LeaseOption
	if rr.Keys {
		opts = append(opts, clientv3.WithAttachedKeys())
	}

	ttl, err := lp.lessor.TimeToLive(ctx, clientv3.LeaseID(rr.ID), opts...)
	if err != nil {
		return nil, err
	}

	return &pb.LeaseTimeToLiveResponse{
		Header:     ttl.ResponseHeader,
		ID:         int64(ttl.ID),
		TTL:        ttl.TTL,
		GrantedTTL: ttl.GrantedTTL,
		Keys:       ttl.Keys,
	}, nil
}

// LeaseKeepAlive serves one bidirectional keepalive stream.
//
// The stream is admitted only while the proxy context is still live. Three
// goroutines are started for it: one receiving client requests (recvLoop),
// one sending responses back (sendLoop), and one watching for leader loss or
// context cancellation. The method returns as soon as any of them finishes;
// a detached goroutine then waits for the remaining ones, releases the stream
// resources and marks the stream as done on lp.wg.
//
// Return value, in order of precedence: rpctypes.ErrNoLeader if the stream
// asked for a leader and the leader is lost, grpc.ErrClientConnClosing if the
// leader tracker reports a disconnect, the first recv/send error if any, and
// otherwise the stream context's error.
func (lp *leaseProxy) LeaseKeepAlive(stream pb.Lease_LeaseKeepAliveServer) error {
	// register the stream under mu so that registration cannot race with the
	// shutdown goroutine in NewLeaseProxy waiting on wg.
	lp.mu.Lock()
	select {
	case <-lp.ctx.Done():
		lp.mu.Unlock()
		return lp.ctx.Err()
	default:
		lp.wg.Add(1)
	}
	lp.mu.Unlock()

	ctx, cancel := context.WithCancel(stream.Context())
	lps := leaseProxyStream{
		stream:          stream,
		lessor:          lp.lessor,
		keepAliveLeases: make(map[int64]*atomicCounter),
		respc:           make(chan *pb.LeaseKeepAliveResponse),
		ctx:             ctx,
		cancel:          cancel,
	}

	// buffered for the two possible senders (recvLoop and sendLoop) so that
	// neither blocks after this function has returned.
	errc := make(chan error, 2)

	// lostLeaderC stays nil unless the client required a leader; receiving
	// from a nil channel blocks forever, so the corresponding select cases
	// below are then never chosen.
	// NOTE(review): metadata is read with FromOutgoingContext on a server
	// stream context — confirm this is where the require-leader key ends up.
	var lostLeaderC <-chan struct{}
	if md, ok := metadata.FromOutgoingContext(stream.Context()); ok {
		v := md[rpctypes.MetadataRequireLeaderKey]
		if len(v) > 0 && v[0] == rpctypes.MetadataHasLeader {
			lostLeaderC = lp.leader.lostNotify()
			// if leader is known to be lost at creation time, avoid
			// letting events through at all
			select {
			case <-lostLeaderC:
				// nothing has been started yet; release the derived
				// context right away instead of leaving it registered
				// with the parent until the stream context ends.
				cancel()
				lp.wg.Done()
				return rpctypes.ErrNoLeader
			default:
			}
		}
	}
	// each of the three goroutines below posts exactly one token on exit.
	stopc := make(chan struct{}, 3)
	go func() {
		defer func() { stopc <- struct{}{} }()
		if err := lps.recvLoop(); err != nil {
			errc <- err
		}
	}()

	go func() {
		defer func() { stopc <- struct{}{} }()
		if err := lps.sendLoop(); err != nil {
			errc <- err
		}
	}()

	// tears down LeaseKeepAlive stream if leader goes down or entire leaseProxy is terminated.
	go func() {
		defer func() { stopc <- struct{}{} }()
		select {
		case <-lostLeaderC:
		case <-ctx.Done():
		case <-lp.ctx.Done():
		}
	}()

	var err error
	select {
	case <-stopc:
		// put the token back so the cleanup goroutine still sees three.
		stopc <- struct{}{}
	case err = <-errc:
	}
	cancel()

	// recv/send may only shutdown after function exits;
	// this goroutine notifies lease proxy that the stream is through
	go func() {
		<-stopc
		<-stopc
		<-stopc
		lps.close()
		close(errc)
		lp.wg.Done()
	}()

	select {
	case <-lostLeaderC:
		return rpctypes.ErrNoLeader
	case <-lp.leader.disconnectNotify():
		return grpc.ErrClientConnClosing
	default:
		if err != nil {
			return err
		}
		return ctx.Err()
	}
}

// leaseProxyStream holds the state of a single LeaseKeepAlive stream served
// by the proxy.
type leaseProxyStream struct {
	// stream is the client-facing gRPC stream that requests are received
	// from and responses are sent to.
	stream pb.Lease_LeaseKeepAliveServer

	// lessor is the clientv3 lease API used to issue KeepAlive and
	// TimeToLive calls for this stream.
	lessor clientv3.Lease
	// wg tracks keepAliveLoop goroutines
	wg sync.WaitGroup
	// mu protects keepAliveLeases
	mu sync.RWMutex
	// keepAliveLeases tracks how many outstanding keepalive requests which need responses are on a lease.
	keepAliveLeases map[int64]*atomicCounter
	// respc receives lease keepalive responses from etcd backend
	respc chan *pb.LeaseKeepAliveResponse

	// ctx is the stream-scoped context; cancel cancels it and thereby stops
	// sendLoop and the keepAliveLoop goroutines.
	ctx    context.Context
	cancel context.CancelFunc
}

// recvLoop reads keepalive requests from the client stream until it ends.
//
// For every request the per-lease counter of outstanding responses is
// incremented. The first request seen for a lease also starts a keepAliveLoop
// goroutine for it; if that loop fails, the whole stream is cancelled.
// Returns nil on io.EOF and the receive error otherwise.
func (lps *leaseProxyStream) recvLoop() error {
	for {
		req, err := lps.stream.Recv()
		switch {
		case err == io.EOF:
			return nil
		case err != nil:
			return err
		}

		lps.mu.Lock()
		pending, tracked := lps.keepAliveLeases[req.ID]
		if !tracked {
			// first request for this lease on this stream: register a
			// counter and start the loop that talks to the backend
			pending = &atomicCounter{}
			lps.keepAliveLeases[req.ID] = pending
			lps.wg.Add(1)
			go func(id int64, counter *atomicCounter) {
				defer lps.wg.Done()
				if loopErr := lps.keepAliveLoop(id, counter); loopErr != nil {
					lps.cancel()
				}
			}(req.ID, pending)
		}
		pending.add(1)
		lps.mu.Unlock()
	}
}

// keepAliveLoop relays backend keepalive responses for one lease to the
// client for as long as the client has requests awaiting a response.
//
// The loop ends with nil when the expiry timer fires while no response is
// owed, or when the backend keepalive channel closes; in the latter case any
// still-owed responses are answered from a single TimeToLive lookup. An error
// is returned only if starting the keepalive or the TimeToLive lookup fails.
func (lps *leaseProxyStream) keepAliveLoop(leaseID int64, neededResps *atomicCounter) error {
	id := clientv3.LeaseID(leaseID)
	loopCtx, stop := context.WithCancel(lps.ctx)
	defer stop()

	upstream, err := lps.lessor.KeepAlive(loopCtx, id)
	if err != nil {
		return err
	}

	// expiry is armed with the lease TTL each time a response is relayed;
	// while nil it never fires
	var expiry <-chan time.Time
	for {
		select {
		case ka, open := <-upstream:
			if open {
				// nobody is waiting for this response; drop it
				if neededResps.get() == 0 {
					continue
				}
				expiry = time.After(time.Duration(ka.TTL) * time.Second)
				lps.replyToClient(&pb.LeaseKeepAliveResponse{
					Header: ka.ResponseHeader,
					ID:     int64(ka.ID),
					TTL:    ka.TTL,
				}, neededResps)
				continue
			}

			// backend channel closed: stop tracking the lease, then answer
			// whatever is still owed using a TimeToLive lookup
			lps.mu.Lock()
			delete(lps.keepAliveLeases, leaseID)
			lps.mu.Unlock()
			if neededResps.get() == 0 {
				return nil
			}
			ttl, ttlErr := lps.lessor.TimeToLive(loopCtx, id)
			if ttlErr != nil {
				return ttlErr
			}
			final := &pb.LeaseKeepAliveResponse{
				Header: ttl.ResponseHeader,
				ID:     int64(ttl.ID),
				TTL:    ttl.TTL,
			}
			for neededResps.get() > 0 {
				select {
				case <-lps.ctx.Done():
					return nil
				case lps.respc <- final:
					neededResps.add(-1)
				}
			}
			return nil

		case <-expiry:
			lps.mu.Lock()
			// requests arrived in the meantime: keep the loop alive and
			// disarm the timer until the next relayed response re-arms it
			if neededResps.get() > 0 {
				lps.mu.Unlock()
				expiry = nil
				continue
			}
			// the check and the removal happen under the same lock so that
			// recvLoop cannot bump the counter in between
			delete(lps.keepAliveLeases, leaseID)
			lps.mu.Unlock()
			return nil
		}
	}
}

// replyToClient hands r to the send loop once per outstanding request,
// decrementing neededResps for each hand-off. It gives up when the counter
// reaches zero, after 500ms in total, or when the stream context is done.
func (lps *leaseProxyStream) replyToClient(r *pb.LeaseKeepAliveResponse, neededResps *atomicCounter) {
	// a single deadline bounds the whole batch, not each individual send
	deadline := time.After(500 * time.Millisecond)
	for {
		if neededResps.get() <= 0 {
			return
		}
		select {
		case <-lps.ctx.Done():
			return
		case <-deadline:
			return
		case lps.respc <- r:
			neededResps.add(-1)
		}
	}
}

// sendLoop forwards every response arriving on respc to the client stream.
// It returns nil once respc is closed, the send error if a send fails, and
// the context error when the stream context is done.
func (lps *leaseProxyStream) sendLoop() error {
	done := lps.ctx.Done()
	for {
		select {
		case <-done:
			return lps.ctx.Err()
		case resp, open := <-lps.respc:
			if !open {
				return nil
			}
			if sendErr := lps.stream.Send(resp); sendErr != nil {
				return sendErr
			}
		}
	}
}

// close cancels the stream context, waits for every keepAliveLoop goroutine
// to exit and then closes respc.
func (lps *leaseProxyStream) close() {
	// respc is closed last, after all keepAliveLoop goroutines are gone, so
	// that none of them can send on a closed channel
	defer close(lps.respc)

	lps.cancel()
	lps.wg.Wait()
}

// atomicCounter is an int64 counter whose reads and updates go through
// sync/atomic. Its zero value is a counter at 0.
type atomicCounter struct {
	counter int64
}

// add atomically adds delta (which may be negative) to the counter.
func (c *atomicCounter) add(delta int64) {
	atomic.AddInt64(&c.counter, delta)
}

// get atomically loads and returns the current counter value.
func (c *atomicCounter) get() int64 {
	return atomic.LoadInt64(&c.counter)
}