diff --git a/conn.go b/conn.go index ff229f76e..2978e751e 100644 --- a/conn.go +++ b/conn.go @@ -2,9 +2,12 @@ package kafka import ( "bufio" + "context" "errors" "fmt" + "github.com/segmentio/kafka-go/sasl" "io" + "log" "math" "net" "os" @@ -81,6 +84,9 @@ type Conn struct { apiVersions atomic.Value // apiVersionMap transactionalID *string + + authLock sync.RWMutex + cancelNextAuthentication chan struct{} } type apiVersionMap map[apiKey]ApiVersion @@ -177,15 +183,16 @@ func NewConnWith(conn net.Conn, config ConnConfig) *Conn { } c := &Conn{ - conn: conn, - rbuf: *bufio.NewReader(conn), - wbuf: *bufio.NewWriter(conn), - clientID: config.ClientID, - topic: config.Topic, - partition: int32(config.Partition), - offset: FirstOffset, - requiredAcks: -1, - transactionalID: emptyToNullable(config.TransactionalID), + conn: conn, + rbuf: *bufio.NewReader(conn), + wbuf: *bufio.NewWriter(conn), + clientID: config.ClientID, + topic: config.Topic, + partition: int32(config.Partition), + offset: FirstOffset, + requiredAcks: -1, + transactionalID: emptyToNullable(config.TransactionalID), + cancelNextAuthentication: make(chan struct{}), } c.wb.w = &c.wbuf @@ -551,6 +558,10 @@ func (c *Conn) syncGroup(request syncGroupRequestV0) (syncGroupResponseV0, error // Close closes the kafka connection. 
func (c *Conn) Close() error { + select { + case c.cancelNextAuthentication <- struct{}{}: + default: + } return c.conn.Close() } @@ -796,6 +807,9 @@ func (c *Conn) ReadBatchWith(cfg ReadBatchConfig) *Batch { return &Batch{err: dontExpectEOF(err)} } + c.authLock.RLock() + defer c.authLock.RUnlock() + id, err := c.doRequest(&c.rdeadline, func(deadline time.Time, id int32) error { now := time.Now() var timeout time.Duration @@ -1303,7 +1317,7 @@ func (c *Conn) concurrency() int { return int(atomic.LoadInt32(&c.inflight)) } -func (c *Conn) do(d *connDeadline, write func(time.Time, int32) error, read func(time.Time, int) error) error { +func (c *Conn) doNoWaitForAuth(d *connDeadline, write func(time.Time, int32) error, read func(time.Time, int) error) error { id, err := c.doRequest(d, write) if err != nil { return err @@ -1318,7 +1332,7 @@ func (c *Conn) do(d *connDeadline, write func(time.Time, int32) error, read func switch err.(type) { case Error: default: - c.conn.Close() + c.Close() } } @@ -1327,6 +1341,12 @@ func (c *Conn) do(d *connDeadline, write func(time.Time, int32) error, read func return err } +func (c *Conn) do(d *connDeadline, write func(time.Time, int32) error, read func(time.Time, int) error) error { + c.authLock.RLock() + defer c.authLock.RUnlock() + return c.doNoWaitForAuth(d, write, read) +} + func (c *Conn) doRequest(d *connDeadline, write func(time.Time, int32) error) (id int32, err error) { c.enter() c.wlock.Lock() @@ -1339,7 +1359,7 @@ func (c *Conn) doRequest(d *connDeadline, write func(time.Time, int32) error) (i // When an error occurs there's no way to know if the connection is in a // recoverable state so we're better off just giving up at this point to // avoid any risk of corrupting the following operations. 
- c.conn.Close() + c.Close() c.leave() } @@ -1411,6 +1431,9 @@ func (c *Conn) ApiVersions() ([]ApiVersion, error) { deadline = &c.wdeadline } + c.authLock.RLock() + defer c.authLock.RUnlock() + id, err := c.doRequest(deadline, func(_ time.Time, id int32) error { h := requestHeader{ ApiKey: int16(apiVersions), @@ -1532,18 +1555,14 @@ func (d *connDeadline) unsetConnWriteDeadline() { // therefore the client should already know which mechanisms are supported. // // See http://kafka.apache.org/protocol.html#The_Messages_SaslHandshake -func (c *Conn) saslHandshake(mechanism string) error { +func (c *Conn) saslHandshake(mechanism string, version apiVersion) error { // The wire format for V0 and V1 is identical, but the version // number will affect how the SASL authentication // challenge/responses are sent var resp saslHandshakeResponseV0 - version, err := c.negotiateVersion(saslHandshake, v0, v1) - if err != nil { - return err - } - - err = c.writeOperation( + err := c.doNoWaitForAuth( + &c.wdeadline, func(deadline time.Time, id int32) error { return c.writeRequest(saslHandshake, version, id, &saslHandshakeRequestV0{Mechanism: mechanism}) }, @@ -1559,25 +1578,87 @@ func (c *Conn) saslHandshake(mechanism string) error { return err } +// authenticateSASL performs all of the required requests to authenticate this +// connection. If any step fails, this function returns with an error. A nil +// error indicates successful authentication. +// +// In case of error, this function *does not* close the connection. That is the +// responsibility of the caller. 
+func (c *Conn) authenticateSASL(ctx context.Context, mechanism sasl.Mechanism, version apiVersion) error { + // Prevent other requests from being sent while re-authenticating + c.authLock.Lock() + defer c.authLock.Unlock() + + if err := c.saslHandshake(mechanism.Name(), version); err != nil { + return err + } + + sess, state, err := mechanism.Start(ctx) + if err != nil { + return err + } + + var sessionLifeTimeMs int64 + for completed := false; !completed; { + var challenge []byte + challenge, sessionLifeTimeMs, err = c.saslAuthenticate(state) + switch err { + case nil: + case io.EOF: + // the broker may communicate a failed exchange by closing the + // connection (esp. in the case where we're passing opaque sasl + // data over the wire since there's no protocol info). + return SASLAuthenticationFailed + default: + return err + } + + completed, state, err = sess.Next(ctx, challenge) + if err != nil { + return err + } + } + + if sessionLifeTimeMs > 0 { + // schedule re-authentication after 80% of the session lifetime elapsed + // maybe a minimum timeout should be implemented, in order to avoid clogging the application + // when a broker returns a session life time too short? + t := time.NewTimer(time.Duration(sessionLifeTimeMs*80/100) * time.Millisecond) + go func() { + select { + case <-t.C: + if err := c.authenticateSASL(ctx, mechanism, version); err != nil { + log.Printf("error authenticating connection: %v", err) + } + case <-c.cancelNextAuthentication: + } + }() + } + + return nil +} + // saslAuthenticate sends the SASL authenticate message. This function must // be immediately preceded by a successful saslHandshake. // // See http://kafka.apache.org/protocol.html#The_Messages_SaslAuthenticate -func (c *Conn) saslAuthenticate(data []byte) ([]byte, error) { +func (c *Conn) saslAuthenticate(data []byte) ([]byte, int64, error) { // if we sent a v1 handshake, then we must encapsulate the authentication // request in a saslAuthenticateRequest. 
otherwise, we read and write raw // bytes. version, err := c.negotiateVersion(saslHandshake, v0, v1) if err != nil { - return nil, err + return nil, 0, err } + if version == v1 { - var request = saslAuthenticateRequestV0{Data: data} - var response saslAuthenticateResponseV0 + var request = saslAuthenticateRequestV1{Data: data} + var response saslAuthenticateResponseV1 - err := c.writeOperation( + err := c.doNoWaitForAuth( + &c.wdeadline, func(deadline time.Time, id int32) error { - return c.writeRequest(saslAuthenticate, v0, id, request) + return c.writeRequest(saslAuthenticate, v1, id, request) }, func(deadline time.Time, size int) error { return expectZeroSize(func() (remain int, err error) { @@ -1588,24 +1669,24 @@ func (c *Conn) saslAuthenticate(data []byte) ([]byte, error) { if err == nil && response.ErrorCode != 0 { err = Error(response.ErrorCode) } - return response.Data, err + return response.Data, response.SessionLifeTimeMs, err } // fall back to opaque bytes on the wire. the broker is expecting these if // it just processed a v0 sasl handshake. 
c.wb.writeInt32(int32(len(data))) if _, err := c.wb.Write(data); err != nil { - return nil, err + return nil, 0, err } if err := c.wb.Flush(); err != nil { - return nil, err + return nil, 0, err } var respLen int32 if _, err := readInt32(&c.rbuf, 4, &respLen); err != nil { - return nil, err + return nil, 0, err } resp, _, err := readNewBytes(&c.rbuf, int(respLen), int(respLen)) - return resp, err + return resp, 0, err } diff --git a/conn_test.go b/conn_test.go index 535e4b791..71aaade32 100644 --- a/conn_test.go +++ b/conn_test.go @@ -1104,7 +1104,12 @@ func TestUnsupportedSASLMechanism(t *testing.T) { } defer conn.Close() - if err := conn.saslHandshake("FOO"); err != UnsupportedSASLMechanism { + version, err := conn.negotiateVersion(saslHandshake, v0, v1) + if err != nil { + t.Errorf("error negotiating version: %v", err) + } + + if err := conn.saslHandshake("FOO", version); err != UnsupportedSASLMechanism { t.Errorf("Expected UnsupportedSASLMechanism but got %v", err) } } diff --git a/dialer.go b/dialer.go index 35eb080cc..964cee806 100644 --- a/dialer.go +++ b/dialer.go @@ -3,7 +3,6 @@ package kafka import ( "context" "crypto/tls" - "io" "net" "strconv" "strings" @@ -273,7 +272,7 @@ func (d *Dialer) connect(ctx context.Context, network, address string, connCfg C conn := NewConnWith(c, connCfg) if d.SASLMechanism != nil { - if err := d.authenticateSASL(ctx, conn); err != nil { + if err := d.authenticateSASLFirstTime(ctx, conn); err != nil { _ = conn.Close() return nil, err } @@ -282,42 +281,15 @@ func (d *Dialer) connect(ctx context.Context, network, address string, connCfg C return conn, nil } -// authenticateSASL performs all of the required requests to authenticate this -// connection. If any step fails, this function returns with an error. A nil -// error indicates successful authentication. -// -// In case of error, this function *does not* close the connection. That is the -// responsibility of the caller. 
-func (d *Dialer) authenticateSASL(ctx context.Context, conn *Conn) error { - if err := conn.saslHandshake(d.SASLMechanism.Name()); err != nil { - return err - } - - sess, state, err := d.SASLMechanism.Start(ctx) +// negotiates the version for future SASL authentication procedures +// and attempts to authenticate the connection for the first time +func (d *Dialer) authenticateSASLFirstTime(ctx context.Context, conn *Conn) error { + version, err := conn.negotiateVersion(saslHandshake, v0, v1) if err != nil { return err } - for completed := false; !completed; { - challenge, err := conn.saslAuthenticate(state) - switch err { - case nil: - case io.EOF: - // the broker may communicate a failed exchange by closing the - // connection (esp. in the case where we're passing opaque sasl - // data over the wire since there's no protocol info). - return SASLAuthenticationFailed - default: - return err - } - - completed, state, err = sess.Next(ctx, challenge) - if err != nil { - return err - } - } - - return nil + return conn.authenticateSASL(ctx, d.SASLMechanism, version) } func (d *Dialer) dialContext(ctx context.Context, network string, address string) (net.Conn, error) { diff --git a/go.sum b/go.sum deleted file mode 100644 index 6f0a37873..000000000 --- a/go.sum +++ /dev/null @@ -1,21 +0,0 @@ -github.com/eapache/go-xerial-snappy v0.0.0-20180814174437-776d5712da21 h1:YEetp8/yCZMuEPMUDHG0CW/brkkEp8mzqk2+ODEitlw= -github.com/eapache/go-xerial-snappy v0.0.0-20180814174437-776d5712da21/go.mod h1:+020luEh2TKB4/GOp8oxxtq0Daoen/Cii55CzbTV6DU= -github.com/golang/snappy v0.0.1 h1:Qgr9rKW7uDUkrbSmQeiDsGa8SjGyCOGtuasMWwvp2P4= -github.com/golang/snappy v0.0.1/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= -github.com/klauspost/compress v1.9.8 h1:VMAMUUOh+gaxKTMk+zqbjsSjsIcUcL/LF4o63i82QyA= -github.com/klauspost/compress v1.9.8/go.mod h1:RyIbtBH6LamlWaDj8nUwkbUhJ87Yi3uG0guNDohfE1A= -github.com/pierrec/lz4 v2.0.5+incompatible h1:2xWsjqPFWcplujydGg4WmhC/6fZqK42wMM8aXeqhl0I= 
-github.com/pierrec/lz4 v2.0.5+incompatible/go.mod h1:pdkljMzZIN41W+lC3N2tnIh5sFi+IEE17M5jbnwPHcY= -github.com/xdg/scram v0.0.0-20180814205039-7eeb5667e42c h1:u40Z8hqBAAQyv+vATcGgV0YCnDjqSL7/q/JyPhhJSPk= -github.com/xdg/scram v0.0.0-20180814205039-7eeb5667e42c/go.mod h1:lB8K/P019DLNhemzwFU4jHLhdvlE6uDZjXFejJXr49I= -github.com/xdg/stringprep v1.0.0 h1:d9X0esnoa3dFsV0FG35rAT0RIhYFlPq7MiP+DW89La0= -github.com/xdg/stringprep v1.0.0/go.mod h1:Jhud4/sHMO4oL310DaZAKk9ZaJ08SJfe+sJh0HrGL1Y= -golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/crypto v0.0.0-20190506204251-e1dfcc566284 h1:rlLehGeYg6jfoyz/eDqDU1iRXLKfR42nnNh57ytKEWo= -golang.org/x/crypto v0.0.0-20190506204251-e1dfcc566284/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= -golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3 h1:0GoQqolDA55aaLxZyTzK/Y2ePZzZTUrRacwib7cNsYQ= -golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= -golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg= -golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= diff --git a/saslauthenticate.go b/saslauthenticate.go index ad1292918..fef4fb346 100644 --- a/saslauthenticate.go +++ b/saslauthenticate.go @@ -4,43 +4,45 @@ import ( "bufio" ) -type saslAuthenticateRequestV0 struct { +type saslAuthenticateRequestV1 struct { // Data holds the SASL payload Data []byte } -func (t saslAuthenticateRequestV0) size() int32 { +func (t saslAuthenticateRequestV1) size() int32 { return sizeofBytes(t.Data) } -func (t *saslAuthenticateRequestV0) readFrom(r *bufio.Reader, sz int) (remain int, err error) { +func (t *saslAuthenticateRequestV1) readFrom(r *bufio.Reader, sz int) 
(remain int, err error) { return readBytes(r, sz, &t.Data) } -func (t saslAuthenticateRequestV0) writeTo(wb *writeBuffer) { +func (t saslAuthenticateRequestV1) writeTo(wb *writeBuffer) { wb.writeBytes(t.Data) } -type saslAuthenticateResponseV0 struct { - // ErrorCode holds response error code +type saslAuthenticateResponseV1 struct { ErrorCode int16 ErrorMessage string Data []byte + + SessionLifeTimeMs int64 } -func (t saslAuthenticateResponseV0) size() int32 { - return sizeofInt16(t.ErrorCode) + sizeofString(t.ErrorMessage) + sizeofBytes(t.Data) +func (t saslAuthenticateResponseV1) size() int32 { + return sizeofInt16(t.ErrorCode) + sizeofString(t.ErrorMessage) + sizeofBytes(t.Data) + sizeofInt64(t.SessionLifeTimeMs) } -func (t saslAuthenticateResponseV0) writeTo(wb *writeBuffer) { +func (t saslAuthenticateResponseV1) writeTo(wb *writeBuffer) { wb.writeInt16(t.ErrorCode) wb.writeString(t.ErrorMessage) wb.writeBytes(t.Data) + wb.writeInt64(t.SessionLifeTimeMs) } -func (t *saslAuthenticateResponseV0) readFrom(r *bufio.Reader, sz int) (remain int, err error) { +func (t *saslAuthenticateResponseV1) readFrom(r *bufio.Reader, sz int) (remain int, err error) { if remain, err = readInt16(r, sz, &t.ErrorCode); err != nil { return } @@ -50,5 +52,9 @@ func (t *saslAuthenticateResponseV0) readFrom(r *bufio.Reader, sz int) (remain i if remain, err = readBytes(r, remain, &t.Data); err != nil { return } + if remain, err = readInt64(r, remain, &t.SessionLifeTimeMs); err != nil { + return + } + return } diff --git a/saslauthenticate_test.go b/saslauthenticate_test.go index 89a33e3da..5eb660402 100644 --- a/saslauthenticate_test.go +++ b/saslauthenticate_test.go @@ -7,8 +7,8 @@ import ( "testing" ) -func TestSASLAuthenticateRequestV0(t *testing.T) { - item := saslAuthenticateRequestV0{ +func TestSASLAuthenticateRequestV1(t *testing.T) { + item := saslAuthenticateRequestV1{ Data: []byte("\x00user\x00pass"), } @@ -16,7 +16,7 @@ func TestSASLAuthenticateRequestV0(t *testing.T) { w := 
&writeBuffer{w: b} item.writeTo(w) - var found saslAuthenticateRequestV0 + var found saslAuthenticateRequestV1 remain, err := (&found).readFrom(bufio.NewReader(b), b.Len()) if err != nil { t.Error(err) @@ -32,18 +32,19 @@ func TestSASLAuthenticateRequestV0(t *testing.T) { } } -func TestSASLAuthenticateResponseV0(t *testing.T) { - item := saslAuthenticateResponseV0{ - ErrorCode: 2, - ErrorMessage: "Message", - Data: []byte("bytes"), +func TestSASLAuthenticateResponseV1(t *testing.T) { + item := saslAuthenticateResponseV1{ + ErrorCode: 2, + ErrorMessage: "Message", + Data: []byte("bytes"), + SessionLifeTimeMs: 1000, } b := bytes.NewBuffer(nil) w := &writeBuffer{w: b} item.writeTo(w) - var found saslAuthenticateResponseV0 + var found saslAuthenticateResponseV1 remain, err := (&found).readFrom(bufio.NewReader(b), b.Len()) if err != nil { t.Error(err)