// vyvanse/vendor/github.com/Shopify/sarama/broker.go
//
// 686 lines
// 18 KiB
// Go

package sarama
import (
"crypto/tls"
"encoding/binary"
"fmt"
"io"
"net"
"strconv"
"sync"
"sync/atomic"
"time"
"github.com/rcrowley/go-metrics"
)
// Broker represents a single Kafka broker connection. All operations on this object are entirely concurrency-safe.
type Broker struct {
// id is the broker ID from Kafka's metadata, or -1 when it is not known (see NewBroker and ID).
id int32
// addr is the host:port address of the broker.
addr string
// conf is the configuration passed to Open; it is stored once the dial has succeeded.
conf *Config
// correlationID is stamped on each outgoing request and incremented after a successful write.
correlationID int32
// conn is the current connection; nil means the broker is not connected.
conn net.Conn
// connErr holds the error of the most recent connection attempt, if any.
connErr error
// lock is held by Open's goroutine while connecting, and by Connected, Close and send.
lock sync.Mutex
// opened is accessed atomically: 1 while the broker is opening/open, 0 otherwise.
opened int32
// responses queues the promises of in-flight requests for responseReceiver, in send order.
responses chan responsePromise
// done is closed by responseReceiver once responses has been closed and drained.
done chan bool
// Global metrics, shared between all brokers using the same registry.
incomingByteRate metrics.Meter
requestRate metrics.Meter
requestSize metrics.Histogram
requestLatency metrics.Histogram
outgoingByteRate metrics.Meter
responseRate metrics.Meter
responseSize metrics.Histogram
// Per-broker metrics; Open only registers these when id >= 0, so they may be nil.
brokerIncomingByteRate metrics.Meter
brokerRequestRate metrics.Meter
brokerRequestSize metrics.Histogram
brokerRequestLatency metrics.Histogram
brokerOutgoingByteRate metrics.Meter
brokerResponseRate metrics.Meter
brokerResponseSize metrics.Histogram
}
// responsePromise is what send hands to responseReceiver for a request that expects a
// response. Exactly one of packets or errors receives a value for each promise.
// NOTE: send builds this type with a positional literal, so the field order matters.
type responsePromise struct {
// requestTime is taken just before the request is written; used to compute request latency.
requestTime time.Time
// correlationID is the ID the request was sent with; the response header must carry the same value.
correlationID int32
// packets receives the raw response body on success.
packets chan []byte
// errors receives the failure when the response could not be read or matched.
errors chan error
}
// NewBroker builds a Broker for the given host:port address with an unknown ID (-1).
// No connection is attempted here; call Open() to actually connect.
func NewBroker(addr string) *Broker {
	broker := new(Broker)
	broker.id = -1
	broker.addr = addr
	return broker
}
// Open tries to connect to the Broker if it is not already connected or connecting, but does not block
// waiting for the connection to complete. This means that any subsequent operations on the broker will
// block waiting for the connection to succeed or fail. To get the effect of a fully synchronous Open call,
// follow it by a call to Connected(). The only errors Open will return directly are ConfigurationError or
// AlreadyConnected. If conf is nil, the result of NewConfig() is used.
//
// If the configuration fails validation the broker is left closed, so Open may be called again
// with a corrected configuration.
func (b *Broker) Open(conf *Config) error {
	// Only one Open may be in flight or in effect at a time.
	if !atomic.CompareAndSwapInt32(&b.opened, 0, 1) {
		return ErrAlreadyConnected
	}

	if conf == nil {
		conf = NewConfig()
	}

	if err := conf.Validate(); err != nil {
		// Nothing was dialed, so release the "opened" flag again; otherwise every later
		// Open would report ErrAlreadyConnected on a broker that never connected.
		atomic.StoreInt32(&b.opened, 0)
		return err
	}

	// The lock is taken here and released by the goroutine below, so that every other
	// operation on the broker blocks until the connection attempt has finished.
	b.lock.Lock()

	go withRecover(func() {
		defer b.lock.Unlock()

		dialer := net.Dialer{
			Timeout:   conf.Net.DialTimeout,
			KeepAlive: conf.Net.KeepAlive,
		}

		if conf.Net.TLS.Enable {
			b.conn, b.connErr = tls.DialWithDialer(&dialer, "tcp", b.addr, conf.Net.TLS.Config)
		} else {
			b.conn, b.connErr = dialer.Dial("tcp", b.addr)
		}
		if b.connErr != nil {
			Logger.Printf("Failed to connect to broker %s: %s\n", b.addr, b.connErr)
			// Force a true nil interface: tls.DialWithDialer returns a typed pointer.
			b.conn = nil
			atomic.StoreInt32(&b.opened, 0)
			return
		}
		b.conn = newBufConn(b.conn)

		b.conf = conf

		// Create or reuse the global metrics shared between brokers
		b.incomingByteRate = metrics.GetOrRegisterMeter("incoming-byte-rate", conf.MetricRegistry)
		b.requestRate = metrics.GetOrRegisterMeter("request-rate", conf.MetricRegistry)
		b.requestSize = getOrRegisterHistogram("request-size", conf.MetricRegistry)
		b.requestLatency = getOrRegisterHistogram("request-latency-in-ms", conf.MetricRegistry)
		b.outgoingByteRate = metrics.GetOrRegisterMeter("outgoing-byte-rate", conf.MetricRegistry)
		b.responseRate = metrics.GetOrRegisterMeter("response-rate", conf.MetricRegistry)
		b.responseSize = getOrRegisterHistogram("response-size", conf.MetricRegistry)
		// Do not gather metrics for seeded broker (only used during bootstrap) because they share
		// the same id (-1) and are already exposed through the global metrics above
		if b.id >= 0 {
			b.brokerIncomingByteRate = getOrRegisterBrokerMeter("incoming-byte-rate", b, conf.MetricRegistry)
			b.brokerRequestRate = getOrRegisterBrokerMeter("request-rate", b, conf.MetricRegistry)
			b.brokerRequestSize = getOrRegisterBrokerHistogram("request-size", b, conf.MetricRegistry)
			b.brokerRequestLatency = getOrRegisterBrokerHistogram("request-latency-in-ms", b, conf.MetricRegistry)
			b.brokerOutgoingByteRate = getOrRegisterBrokerMeter("outgoing-byte-rate", b, conf.MetricRegistry)
			b.brokerResponseRate = getOrRegisterBrokerMeter("response-rate", b, conf.MetricRegistry)
			b.brokerResponseSize = getOrRegisterBrokerHistogram("response-size", b, conf.MetricRegistry)
		}

		if conf.Net.SASL.Enable {
			b.connErr = b.sendAndReceiveSASLPlainAuth()
			if b.connErr != nil {
				// Authentication failed: drop the connection and allow Open to be retried.
				closeErr := b.conn.Close()
				if closeErr == nil {
					Logger.Printf("Closed connection to broker %s\n", b.addr)
				} else {
					Logger.Printf("Error while closing connection to broker %s: %s\n", b.addr, closeErr)
				}
				b.conn = nil
				atomic.StoreInt32(&b.opened, 0)
				return
			}
		}

		b.done = make(chan bool)
		// The buffer bounds how many promises can be queued for responseReceiver before send blocks.
		b.responses = make(chan responsePromise, b.conf.Net.MaxOpenRequests-1)

		if b.id >= 0 {
			Logger.Printf("Connected to broker at %s (registered as #%d)\n", b.addr, b.id)
		} else {
			Logger.Printf("Connected to broker at %s (unregistered)\n", b.addr)
		}
		go withRecover(b.responseReceiver)
	})

	return nil
}
// Connected reports whether the broker currently holds a connection. The second result is
// the error recorded by the last connection attempt, if there was one. Because it takes the
// broker lock, it blocks while an Open is still in progress.
func (b *Broker) Connected() (bool, error) {
	b.lock.Lock()
	isConnected, lastErr := b.conn != nil, b.connErr
	b.lock.Unlock()

	return isConnected, lastErr
}
// Close tears down the connection to the broker. It returns ErrNotConnected when there is
// no connection. Otherwise it closes the responses channel, waits for the response receiver
// to finish, closes the connection, clears the connection state and returns the error (if
// any) reported by closing the connection.
func (b *Broker) Close() error {
	b.lock.Lock()
	defer b.lock.Unlock()

	if b.conn == nil {
		return ErrNotConnected
	}

	// Stop the receiver first and wait until it has drained every pending promise.
	close(b.responses)
	<-b.done

	closeErr := b.conn.Close()

	b.conn, b.connErr = nil, nil
	b.done, b.responses = nil, nil

	if closeErr != nil {
		Logger.Printf("Error while closing connection to broker %s: %s\n", b.addr, closeErr)
	} else {
		Logger.Printf("Closed connection to broker %s\n", b.addr)
	}

	// Allow the broker to be opened again.
	atomic.StoreInt32(&b.opened, 0)

	return closeErr
}
// ID reports the broker ID as learned from Kafka's metadata; it is -1 when the ID is not known.
func (b *Broker) ID() int32 {
	id := b.id
	return id
}
// Addr reports the broker's address, which is either the one passed to NewBroker or the
// one retrieved from Kafka's metadata.
func (b *Broker) Addr() string {
	address := b.addr
	return address
}
// GetMetadata sends a MetadataRequest to the broker and returns the decoded
// MetadataResponse, or the error encountered while sending or receiving.
func (b *Broker) GetMetadata(request *MetadataRequest) (*MetadataResponse, error) {
	res := new(MetadataResponse)
	if err := b.sendAndReceive(request, res); err != nil {
		return nil, err
	}
	return res, nil
}
// GetConsumerMetadata sends a ConsumerMetadataRequest to the broker and returns the decoded
// ConsumerMetadataResponse, or the error encountered while sending or receiving.
func (b *Broker) GetConsumerMetadata(request *ConsumerMetadataRequest) (*ConsumerMetadataResponse, error) {
	res := new(ConsumerMetadataResponse)
	if err := b.sendAndReceive(request, res); err != nil {
		return nil, err
	}
	return res, nil
}
// GetAvailableOffsets sends an OffsetRequest to the broker and returns the decoded
// OffsetResponse, or the error encountered while sending or receiving.
func (b *Broker) GetAvailableOffsets(request *OffsetRequest) (*OffsetResponse, error) {
	res := new(OffsetResponse)
	if err := b.sendAndReceive(request, res); err != nil {
		return nil, err
	}
	return res, nil
}
// Produce sends a ProduceRequest to the broker. When the request asks for no acknowledgement
// (RequiredAcks == NoResponse) nothing is read back and a nil response is returned;
// otherwise the decoded ProduceResponse is returned.
func (b *Broker) Produce(request *ProduceRequest) (*ProduceResponse, error) {
	if request.RequiredAcks == NoResponse {
		// Fire and forget: the broker will not answer, so do not wait for a response.
		if err := b.sendAndReceive(request, nil); err != nil {
			return nil, err
		}
		return nil, nil
	}

	res := new(ProduceResponse)
	if err := b.sendAndReceive(request, res); err != nil {
		return nil, err
	}
	return res, nil
}
// Fetch sends a FetchRequest to the broker and returns the decoded FetchResponse, or the
// error encountered while sending or receiving.
func (b *Broker) Fetch(request *FetchRequest) (*FetchResponse, error) {
	res := new(FetchResponse)
	if err := b.sendAndReceive(request, res); err != nil {
		return nil, err
	}
	return res, nil
}
// CommitOffset sends an OffsetCommitRequest to the broker and returns the decoded
// OffsetCommitResponse, or the error encountered while sending or receiving.
func (b *Broker) CommitOffset(request *OffsetCommitRequest) (*OffsetCommitResponse, error) {
	res := new(OffsetCommitResponse)
	if err := b.sendAndReceive(request, res); err != nil {
		return nil, err
	}
	return res, nil
}
// FetchOffset sends an OffsetFetchRequest to the broker and returns the decoded
// OffsetFetchResponse, or the error encountered while sending or receiving.
func (b *Broker) FetchOffset(request *OffsetFetchRequest) (*OffsetFetchResponse, error) {
	res := new(OffsetFetchResponse)
	if err := b.sendAndReceive(request, res); err != nil {
		return nil, err
	}
	return res, nil
}
// JoinGroup sends a JoinGroupRequest to the broker and returns the decoded
// JoinGroupResponse, or the error encountered while sending or receiving.
func (b *Broker) JoinGroup(request *JoinGroupRequest) (*JoinGroupResponse, error) {
	res := new(JoinGroupResponse)
	if err := b.sendAndReceive(request, res); err != nil {
		return nil, err
	}
	return res, nil
}
// SyncGroup sends a SyncGroupRequest to the broker and returns the decoded
// SyncGroupResponse, or the error encountered while sending or receiving.
func (b *Broker) SyncGroup(request *SyncGroupRequest) (*SyncGroupResponse, error) {
	res := new(SyncGroupResponse)
	if err := b.sendAndReceive(request, res); err != nil {
		return nil, err
	}
	return res, nil
}
// LeaveGroup sends a LeaveGroupRequest to the broker and returns the decoded
// LeaveGroupResponse, or the error encountered while sending or receiving.
func (b *Broker) LeaveGroup(request *LeaveGroupRequest) (*LeaveGroupResponse, error) {
	res := new(LeaveGroupResponse)
	if err := b.sendAndReceive(request, res); err != nil {
		return nil, err
	}
	return res, nil
}
// Heartbeat sends a HeartbeatRequest to the broker and returns the decoded
// HeartbeatResponse, or the error encountered while sending or receiving.
func (b *Broker) Heartbeat(request *HeartbeatRequest) (*HeartbeatResponse, error) {
	res := new(HeartbeatResponse)
	if err := b.sendAndReceive(request, res); err != nil {
		return nil, err
	}
	return res, nil
}
// ListGroups sends a ListGroupsRequest to the broker and returns the decoded
// ListGroupsResponse, or the error encountered while sending or receiving.
func (b *Broker) ListGroups(request *ListGroupsRequest) (*ListGroupsResponse, error) {
	res := new(ListGroupsResponse)
	if err := b.sendAndReceive(request, res); err != nil {
		return nil, err
	}
	return res, nil
}
// DescribeGroups sends a DescribeGroupsRequest to the broker and returns the decoded
// DescribeGroupsResponse, or the error encountered while sending or receiving.
func (b *Broker) DescribeGroups(request *DescribeGroupsRequest) (*DescribeGroupsResponse, error) {
	res := new(DescribeGroupsResponse)
	if err := b.sendAndReceive(request, res); err != nil {
		return nil, err
	}
	return res, nil
}
// ApiVersions sends an ApiVersionsRequest to the broker and returns the decoded
// ApiVersionsResponse, or the error encountered while sending or receiving.
func (b *Broker) ApiVersions(request *ApiVersionsRequest) (*ApiVersionsResponse, error) {
	res := new(ApiVersionsResponse)
	if err := b.sendAndReceive(request, res); err != nil {
		return nil, err
	}
	return res, nil
}
// send encodes rb as a request carrying the broker's current correlation ID and writes it
// to the connection while holding the broker lock.
//
// It fails with the stored connection error (or ErrNotConnected) when there is no
// connection, and with ErrUnsupportedVersion when the configured Kafka version is older
// than the request requires. When promiseResponse is false it records the request latency
// and returns (nil, nil). Otherwise it queues a responsePromise for the response receiver
// and returns it so the caller can wait for the outcome.
func (b *Broker) send(rb protocolBody, promiseResponse bool) (*responsePromise, error) {
	b.lock.Lock()
	defer b.lock.Unlock()

	if b.conn == nil && b.connErr != nil {
		return nil, b.connErr
	}
	if b.conn == nil {
		return nil, ErrNotConnected
	}

	if required := rb.requiredVersion(); !b.conf.Version.IsAtLeast(required) {
		return nil, ErrUnsupportedVersion
	}

	outgoing := &request{correlationID: b.correlationID, clientID: b.conf.ClientID, body: rb}
	payload, err := encode(outgoing, b.conf.MetricRegistry)
	if err != nil {
		return nil, err
	}

	writeDeadline := time.Now().Add(b.conf.Net.WriteTimeout)
	if err = b.conn.SetWriteDeadline(writeDeadline); err != nil {
		return nil, err
	}

	sentAt := time.Now()
	written, err := b.conn.Write(payload)
	// Outgoing metrics are recorded even for a failed (possibly partial) write.
	b.updateOutgoingCommunicationMetrics(written)
	if err != nil {
		return nil, err
	}
	// The correlation ID only advances once the request has actually been written.
	b.correlationID++

	if !promiseResponse {
		// Record request latency without the response
		b.updateRequestLatencyMetrics(time.Since(sentAt))
		return nil, nil
	}

	promise := &responsePromise{
		requestTime:   sentAt,
		correlationID: outgoing.correlationID,
		packets:       make(chan []byte),
		errors:        make(chan error),
	}
	// May block (with the lock held) while the responses buffer is full.
	b.responses <- *promise

	return promise, nil
}
// sendAndReceive sends req and, when res is non-nil, waits for the matching response and
// decodes it into res using the request's version. With a nil res no response is awaited
// and nil is returned as soon as the request has been sent.
func (b *Broker) sendAndReceive(req protocolBody, res versionedDecoder) error {
	expectResponse := res != nil

	promise, err := b.send(req, expectResponse)
	switch {
	case err != nil:
		return err
	case promise == nil:
		// Nothing to wait for.
		return nil
	}

	select {
	case recvErr := <-promise.errors:
		return recvErr
	case packet := <-promise.packets:
		return versionedDecode(packet, res, req.version())
	}
}
// decode reads the broker's ID, host and port (in that order) from pd and stores the ID
// and the joined host:port address on the broker. It returns the first decoding error, or
// the error from checking that the resulting address splits back into host and port.
func (b *Broker) decode(pd packetDecoder) (err error) {
	if b.id, err = pd.getInt32(); err != nil {
		return err
	}

	var host string
	if host, err = pd.getString(); err != nil {
		return err
	}

	var port int32
	if port, err = pd.getInt32(); err != nil {
		return err
	}

	b.addr = net.JoinHostPort(host, fmt.Sprint(port))

	// Sanity-check the address that was just assembled.
	_, _, err = net.SplitHostPort(b.addr)
	return err
}
// encode writes the broker's ID, host and port (in that order) to pe. The address is split
// and the port parsed before anything is written, so a malformed address produces an error
// without emitting partial output.
func (b *Broker) encode(pe packetEncoder) (err error) {
	host, portText, err := net.SplitHostPort(b.addr)
	if err != nil {
		return err
	}

	portNum, err := strconv.Atoi(portText)
	if err != nil {
		return err
	}

	pe.putInt32(b.id)

	if err = pe.putString(host); err != nil {
		return err
	}

	pe.putInt32(int32(portNum))

	return nil
}
// responseReceiver is started as a goroutine by Open. It consumes promises from b.responses
// in the order send queued them, reads one response from the connection for each, and
// delivers either the raw body (on promise.packets) or an error (on promise.errors).
// When b.responses is closed and drained (see Close) it closes b.done and returns.
//
// The first failure is sticky: it is stored in dead, and every later promise is failed
// with that same error without touching the connection again.
func (b *Broker) responseReceiver() {
var dead error
// header is reused for every response; it is decoded into a responseHeader below.
header := make([]byte, 8)
for response := range b.responses {
if dead != nil {
response.errors <- dead
continue
}
err := b.conn.SetReadDeadline(time.Now().Add(b.conf.Net.ReadTimeout))
if err != nil {
dead = err
response.errors <- err
continue
}
bytesReadHeader, err := io.ReadFull(b.conn, header)
// Latency is measured from just before the request was written until the header read returns.
requestLatency := time.Since(response.requestTime)
if err != nil {
b.updateIncomingCommunicationMetrics(bytesReadHeader, requestLatency)
dead = err
response.errors <- err
continue
}
decodedHeader := responseHeader{}
err = decode(header, &decodedHeader)
if err != nil {
b.updateIncomingCommunicationMetrics(bytesReadHeader, requestLatency)
dead = err
response.errors <- err
continue
}
if decodedHeader.correlationID != response.correlationID {
b.updateIncomingCommunicationMetrics(bytesReadHeader, requestLatency)
// TODO if decoded ID < cur ID, discard until we catch up
// TODO if decoded ID > cur ID, save it so when cur ID catches up we have a response
dead = PacketDecodingError{fmt.Sprintf("correlation ID didn't match, wanted %d, got %d", response.correlationID, decodedHeader.correlationID)}
response.errors <- dead
continue
}
// The body is 4 bytes shorter than the declared length.
// NOTE(review): this presumably relies on responseHeader decoding rejecting length < 4,
// otherwise make would panic on a negative size — confirm.
buf := make([]byte, decodedHeader.length-4)
bytesReadBody, err := io.ReadFull(b.conn, buf)
b.updateIncomingCommunicationMetrics(bytesReadHeader+bytesReadBody, requestLatency)
if err != nil {
dead = err
response.errors <- err
continue
}
response.packets <- buf
}
close(b.done)
}
// sendAndReceiveSASLPlainHandshake sends a SaslHandshakeRequest for the "PLAIN" mechanism
// and synchronously reads and decodes the broker's SaslHandshakeResponse. It talks to
// b.conn directly and does not take the broker lock itself.
//
// It returns any encoding, I/O or decoding error, an error when the response declares an
// impossible length, or the Kafka error carried by the response when the broker rejects
// the mechanism.
func (b *Broker) sendAndReceiveSASLPlainHandshake() error {
	rb := &SaslHandshakeRequest{"PLAIN"}
	req := &request{correlationID: b.correlationID, clientID: b.conf.ClientID, body: rb}
	buf, err := encode(req, b.conf.MetricRegistry)
	if err != nil {
		return err
	}

	err = b.conn.SetWriteDeadline(time.Now().Add(b.conf.Net.WriteTimeout))
	if err != nil {
		return err
	}

	requestTime := time.Now()
	bytes, err := b.conn.Write(buf)
	b.updateOutgoingCommunicationMetrics(bytes)
	if err != nil {
		Logger.Printf("Failed to send SASL handshake %s: %s\n", b.addr, err.Error())
		return err
	}
	b.correlationID++

	// Bound the wait for the response, as the response receiver does for regular requests;
	// without a deadline an unresponsive peer would block this call forever.
	err = b.conn.SetReadDeadline(time.Now().Add(b.conf.Net.ReadTimeout))
	if err != nil {
		Logger.Printf("Failed to set read deadline for SASL handshake %s: %s\n", b.addr, err.Error())
		return err
	}

	// wait for the response
	header := make([]byte, 8) // response header
	_, err = io.ReadFull(b.conn, header)
	if err != nil {
		Logger.Printf("Failed to read SASL handshake header : %s\n", err.Error())
		return err
	}

	length := binary.BigEndian.Uint32(header[:4])
	// The payload is read as length-4 bytes. A smaller declared length is malformed and
	// would wrap the unsigned subtraction into a huge allocation.
	if length < 4 {
		err = fmt.Errorf("kafka: invalid SASL handshake response length %d", length)
		Logger.Printf("Failed to read SASL handshake header : %s\n", err.Error())
		return err
	}

	payload := make([]byte, length-4)
	n, err := io.ReadFull(b.conn, payload)
	if err != nil {
		Logger.Printf("Failed to read SASL handshake payload : %s\n", err.Error())
		return err
	}

	b.updateIncomingCommunicationMetrics(n+8, time.Since(requestTime))

	res := &SaslHandshakeResponse{}
	err = versionedDecode(payload, res, 0)
	if err != nil {
		Logger.Printf("Failed to parse SASL handshake : %s\n", err.Error())
		return err
	}

	if res.Err != ErrNoError {
		Logger.Printf("Invalid SASL Mechanism : %s\n", res.Err.Error())
		return res.Err
	}

	Logger.Print("Successful SASL handshake")
	return nil
}
// Kafka 0.10.0 plans to support SASL Plain and Kerberos as per PR #812 (KIP-43)/(JIRA KAFKA-3149)
// Some hosted kafka services such as IBM Message Hub already offer SASL/PLAIN auth with Kafka 0.9
//
// In SASL Plain, Kafka expects the auth header to be in the following format
// Message format (from https://tools.ietf.org/html/rfc4616):
//
//	message   = [authzid] UTF8NUL authcid UTF8NUL passwd
//	authcid   = 1*SAFE ; MUST accept up to 255 octets
//	authzid   = 1*SAFE ; MUST accept up to 255 octets
//	passwd    = 1*SAFE ; MUST accept up to 255 octets
//	UTF8NUL   = %x00 ; UTF-8 encoded NUL character
//
//	SAFE      = UTF1 / UTF2 / UTF3 / UTF4
//	               ;; any UTF-8 encoded Unicode character except NUL
//
// When credentials are valid, Kafka returns a 4 byte array of null characters.
// When credentials are invalid, Kafka closes the connection. This does not seem to be the ideal way
// of responding to bad credentials but that's how it's being done today.
//
// sendAndReceiveSASLPlainAuth optionally performs the SASL handshake (when
// conf.Net.SASL.Handshake is set), then writes the length-prefixed PLAIN credentials with an
// empty authzid and reads the 4-byte reply. Both the write and the read are bounded by the
// configured timeouts. It returns the first handshake or I/O error encountered.
func (b *Broker) sendAndReceiveSASLPlainAuth() error {
	if b.conf.Net.SASL.Handshake {
		handshakeErr := b.sendAndReceiveSASLPlainHandshake()
		if handshakeErr != nil {
			Logger.Printf("Error while performing SASL handshake %s\n", b.addr)
			return handshakeErr
		}
	}

	// NUL + user + NUL + password, preceded by a 4-byte big-endian length.
	length := 1 + len(b.conf.Net.SASL.User) + 1 + len(b.conf.Net.SASL.Password)
	authBytes := make([]byte, length+4) //4 byte length header + auth data
	binary.BigEndian.PutUint32(authBytes, uint32(length))
	copy(authBytes[4:], []byte("\x00"+b.conf.Net.SASL.User+"\x00"+b.conf.Net.SASL.Password))

	err := b.conn.SetWriteDeadline(time.Now().Add(b.conf.Net.WriteTimeout))
	if err != nil {
		Logger.Printf("Failed to set write deadline when doing SASL auth with broker %s: %s\n", b.addr, err.Error())
		return err
	}

	requestTime := time.Now()
	bytesWritten, err := b.conn.Write(authBytes)
	b.updateOutgoingCommunicationMetrics(bytesWritten)
	if err != nil {
		Logger.Printf("Failed to write SASL auth header to broker %s: %s\n", b.addr, err.Error())
		return err
	}

	// Bound the wait for the reply, as the response receiver does for regular requests;
	// without a deadline an unresponsive peer would block this call forever.
	err = b.conn.SetReadDeadline(time.Now().Add(b.conf.Net.ReadTimeout))
	if err != nil {
		Logger.Printf("Failed to set read deadline when doing SASL auth with broker %s: %s\n", b.addr, err.Error())
		return err
	}

	header := make([]byte, 4)
	n, err := io.ReadFull(b.conn, header)
	b.updateIncomingCommunicationMetrics(n, time.Since(requestTime))
	// If the credentials are valid, we would get a 4 byte response filled with null characters.
	// Otherwise, the broker closes the connection and we get an EOF
	if err != nil {
		Logger.Printf("Failed to read response while authenticating with SASL to broker %s: %s\n", b.addr, err.Error())
		return err
	}

	Logger.Printf("SASL authentication successful with broker %s:%v - %v\n", b.addr, n, header)
	return nil
}
// updateIncomingCommunicationMetrics records one received response of the given size, plus
// the latency of the request it answers, on the global metrics and, where they are
// registered, on the per-broker metrics.
func (b *Broker) updateIncomingCommunicationMetrics(bytes int, requestLatency time.Duration) {
	size := int64(bytes)

	b.updateRequestLatencyMetrics(requestLatency)

	b.responseRate.Mark(1)
	if perBroker := b.brokerResponseRate; perBroker != nil {
		perBroker.Mark(1)
	}

	b.incomingByteRate.Mark(size)
	if perBroker := b.brokerIncomingByteRate; perBroker != nil {
		perBroker.Mark(size)
	}

	b.responseSize.Update(size)
	if perBroker := b.brokerResponseSize; perBroker != nil {
		perBroker.Update(size)
	}
}
// updateRequestLatencyMetrics records the latency, truncated to whole milliseconds, on the
// global request-latency histogram and, if registered, on the per-broker one.
func (b *Broker) updateRequestLatencyMetrics(requestLatency time.Duration) {
	millis := int64(requestLatency / time.Millisecond)

	b.requestLatency.Update(millis)
	if perBroker := b.brokerRequestLatency; perBroker != nil {
		perBroker.Update(millis)
	}
}
// updateOutgoingCommunicationMetrics records one sent request of the given size on the
// global metrics and, where they are registered, on the per-broker metrics.
func (b *Broker) updateOutgoingCommunicationMetrics(bytes int) {
	size := int64(bytes)

	b.requestRate.Mark(1)
	if perBroker := b.brokerRequestRate; perBroker != nil {
		perBroker.Mark(1)
	}

	b.outgoingByteRate.Mark(size)
	if perBroker := b.brokerOutgoingByteRate; perBroker != nil {
		perBroker.Mark(size)
	}

	b.requestSize.Update(size)
	if perBroker := b.brokerRequestSize; perBroker != nil {
		perBroker.Update(size)
	}
}