Skip to content

Support for re-authentication #492

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
141 changes: 111 additions & 30 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,12 @@ package kafka

import (
"bufio"
"context"
"errors"
"fmt"
"github.com/segmentio/kafka-go/sasl"
"io"
"log"
"math"
"net"
"os"
Expand Down Expand Up @@ -81,6 +84,9 @@ type Conn struct {
apiVersions atomic.Value // apiVersionMap

transactionalID *string

authLock sync.RWMutex
cancelNextAuthentication chan struct{}
}

type apiVersionMap map[apiKey]ApiVersion
Expand Down Expand Up @@ -177,15 +183,16 @@ func NewConnWith(conn net.Conn, config ConnConfig) *Conn {
}

c := &Conn{
conn: conn,
rbuf: *bufio.NewReader(conn),
wbuf: *bufio.NewWriter(conn),
clientID: config.ClientID,
topic: config.Topic,
partition: int32(config.Partition),
offset: FirstOffset,
requiredAcks: -1,
transactionalID: emptyToNullable(config.TransactionalID),
conn: conn,
rbuf: *bufio.NewReader(conn),
wbuf: *bufio.NewWriter(conn),
clientID: config.ClientID,
topic: config.Topic,
partition: int32(config.Partition),
offset: FirstOffset,
requiredAcks: -1,
transactionalID: emptyToNullable(config.TransactionalID),
cancelNextAuthentication: make(chan struct{}),
}

c.wb.w = &c.wbuf
Expand Down Expand Up @@ -551,6 +558,10 @@ func (c *Conn) syncGroup(request syncGroupRequestV0) (syncGroupResponseV0, error

// Close closes the kafka connection.
func (c *Conn) Close() error {
select {
case c.cancelNextAuthentication <- struct{}{}:
default:
}
return c.conn.Close()
}

Expand Down Expand Up @@ -796,6 +807,9 @@ func (c *Conn) ReadBatchWith(cfg ReadBatchConfig) *Batch {
return &Batch{err: dontExpectEOF(err)}
}

c.authLock.RLock()
defer c.authLock.RUnlock()

id, err := c.doRequest(&c.rdeadline, func(deadline time.Time, id int32) error {
now := time.Now()
var timeout time.Duration
Expand Down Expand Up @@ -1303,7 +1317,7 @@ func (c *Conn) concurrency() int {
return int(atomic.LoadInt32(&c.inflight))
}

func (c *Conn) do(d *connDeadline, write func(time.Time, int32) error, read func(time.Time, int) error) error {
func (c *Conn) doNoWaitForAuth(d *connDeadline, write func(time.Time, int32) error, read func(time.Time, int) error) error {
id, err := c.doRequest(d, write)
if err != nil {
return err
Expand All @@ -1318,7 +1332,7 @@ func (c *Conn) do(d *connDeadline, write func(time.Time, int32) error, read func
switch err.(type) {
case Error:
default:
c.conn.Close()
c.Close()
}
}

Expand All @@ -1327,6 +1341,12 @@ func (c *Conn) do(d *connDeadline, write func(time.Time, int32) error, read func
return err
}

func (c *Conn) do(d *connDeadline, write func(time.Time, int32) error, read func(time.Time, int) error) error {
c.authLock.RLock()
defer c.authLock.RUnlock()
return c.doNoWaitForAuth(d, write, read)
}

func (c *Conn) doRequest(d *connDeadline, write func(time.Time, int32) error) (id int32, err error) {
c.enter()
c.wlock.Lock()
Expand All @@ -1339,7 +1359,7 @@ func (c *Conn) doRequest(d *connDeadline, write func(time.Time, int32) error) (i
// When an error occurs there's no way to know if the connection is in a
// recoverable state so we're better off just giving up at this point to
// avoid any risk of corrupting the following operations.
c.conn.Close()
c.Close()
c.leave()
}

Expand Down Expand Up @@ -1411,6 +1431,9 @@ func (c *Conn) ApiVersions() ([]ApiVersion, error) {
deadline = &c.wdeadline
}

c.authLock.RLock()
defer c.authLock.RUnlock()

id, err := c.doRequest(deadline, func(_ time.Time, id int32) error {
h := requestHeader{
ApiKey: int16(apiVersions),
Expand Down Expand Up @@ -1532,18 +1555,14 @@ func (d *connDeadline) unsetConnWriteDeadline() {
// therefore the client should already know which mechanisms are supported.
//
// See http://kafka.apache.org/protocol.html#The_Messages_SaslHandshake
func (c *Conn) saslHandshake(mechanism string) error {
func (c *Conn) saslHandshake(mechanism string, version apiVersion) error {
// The wire format for V0 and V1 is identical, but the version
// number will affect how the SASL authentication
// challenge/responses are sent
var resp saslHandshakeResponseV0

version, err := c.negotiateVersion(saslHandshake, v0, v1)
if err != nil {
return err
}

err = c.writeOperation(
err := c.doNoWaitForAuth(
&c.wdeadline,
func(deadline time.Time, id int32) error {
return c.writeRequest(saslHandshake, version, id, &saslHandshakeRequestV0{Mechanism: mechanism})
},
Expand All @@ -1559,25 +1578,87 @@ func (c *Conn) saslHandshake(mechanism string) error {
return err
}

// performs all of the required requests to authenticate this
// connection. If any step fails, this function returns with an error. A nil
// error indicates successful authentication.
//
// In case of error, this function *does not* close the connection. That is the
// responsibility of the caller.
func (c *Conn) authenticateSASL(ctx context.Context, mechanism sasl.Mechanism, version apiVersion) error {
//Prevent other requests from being sent while re-authenticating
c.authLock.Lock()
defer c.authLock.Unlock()

if err := c.saslHandshake(mechanism.Name(), version); err != nil {
return err
}

sess, state, err := mechanism.Start(ctx)
if err != nil {
return err
}

var sessionLifeTimeMs int64
for completed := false; !completed; {
var challenge []byte
challenge, sessionLifeTimeMs, err = c.saslAuthenticate(state)
switch err {
case nil:
case io.EOF:
// the broker may communicate a failed exchange by closing the
// connection (esp. in the case where we're passing opaque sasl
// data over the wire since there's no protocol info).
return SASLAuthenticationFailed
default:
return err
}

completed, state, err = sess.Next(ctx, challenge)
if err != nil {
return err
}
}

if sessionLifeTimeMs > 0 {
// schedule re-authentication after 80% of the session lifetime elapsed
// maybe a minimum timeout should be implemented, in order to avoid cloging the application
// when a broker returns a session life time too short?
t := time.NewTimer(time.Duration(sessionLifeTimeMs*80/100) * time.Millisecond)
go func() {
select {
case <-t.C:
if err := c.authenticateSASL(ctx, mechanism, version); err != nil {
log.Printf("error authenticating connection: %v", err)
}
case <-c.cancelNextAuthentication:
}
}()
}

return nil
}

// saslAuthenticate sends the SASL authenticate message. This function must
// be immediately preceded by a successful saslHandshake.
//
// See http://kafka.apache.org/protocol.html#The_Messages_SaslAuthenticate
func (c *Conn) saslAuthenticate(data []byte) ([]byte, error) {
func (c *Conn) saslAuthenticate(data []byte) ([]byte, int64, error) {
// if we sent a v1 handshake, then we must encapsulate the authentication
// request in a saslAuthenticateRequest. otherwise, we read and write raw
// bytes.
version, err := c.negotiateVersion(saslHandshake, v0, v1)
if err != nil {
return nil, err
return nil, 0, err
}

if version == v1 {
var request = saslAuthenticateRequestV0{Data: data}
var response saslAuthenticateResponseV0
var request = saslAuthenticateRequestV1{Data: data}
var response saslAuthenticateResponseV1

err := c.writeOperation(
err := c.doNoWaitForAuth(
&c.wdeadline,
func(deadline time.Time, id int32) error {
return c.writeRequest(saslAuthenticate, v0, id, request)
return c.writeRequest(saslAuthenticate, v1, id, request)
},
func(deadline time.Time, size int) error {
return expectZeroSize(func() (remain int, err error) {
Expand All @@ -1588,24 +1669,24 @@ func (c *Conn) saslAuthenticate(data []byte) ([]byte, error) {
if err == nil && response.ErrorCode != 0 {
err = Error(response.ErrorCode)
}
return response.Data, err
return response.Data, response.SessionLifeTimeMs, err
}

// fall back to opaque bytes on the wire. the broker is expecting these if
// it just processed a v0 sasl handshake.
c.wb.writeInt32(int32(len(data)))
if _, err := c.wb.Write(data); err != nil {
return nil, err
return nil, 0, err
}
if err := c.wb.Flush(); err != nil {
return nil, err
return nil, 0, err
}

var respLen int32
if _, err := readInt32(&c.rbuf, 4, &respLen); err != nil {
return nil, err
return nil, 0, err
}

resp, _, err := readNewBytes(&c.rbuf, int(respLen), int(respLen))
return resp, err
return resp, 0, err
}
7 changes: 6 additions & 1 deletion conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1104,7 +1104,12 @@ func TestUnsupportedSASLMechanism(t *testing.T) {
}
defer conn.Close()

if err := conn.saslHandshake("FOO"); err != UnsupportedSASLMechanism {
version, err := conn.negotiateVersion(saslHandshake, v0, v1)
if err != nil {
t.Errorf("error negotiating version: %v", err)
}

if err := conn.saslHandshake("FOO", version); err != UnsupportedSASLMechanism {
t.Errorf("Expected UnsupportedSASLMechanism but got %v", err)
}
}
Expand Down
40 changes: 6 additions & 34 deletions dialer.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package kafka
import (
"context"
"crypto/tls"
"io"
"net"
"strconv"
"strings"
Expand Down Expand Up @@ -273,7 +272,7 @@ func (d *Dialer) connect(ctx context.Context, network, address string, connCfg C
conn := NewConnWith(c, connCfg)

if d.SASLMechanism != nil {
if err := d.authenticateSASL(ctx, conn); err != nil {
if err := d.authenticateSASLFirstTime(ctx, conn); err != nil {
_ = conn.Close()
return nil, err
}
Expand All @@ -282,42 +281,15 @@ func (d *Dialer) connect(ctx context.Context, network, address string, connCfg C
return conn, nil
}

// authenticateSASL performs all of the required requests to authenticate this
// connection. If any step fails, this function returns with an error. A nil
// error indicates successful authentication.
//
// In case of error, this function *does not* close the connection. That is the
// responsibility of the caller.
func (d *Dialer) authenticateSASL(ctx context.Context, conn *Conn) error {
if err := conn.saslHandshake(d.SASLMechanism.Name()); err != nil {
return err
}

sess, state, err := d.SASLMechanism.Start(ctx)
// negotiates the version for future SASL authentication procedures
// and attempts to authenticate the connection for the first time
func (d *Dialer) authenticateSASLFirstTime(ctx context.Context, conn *Conn) error {
version, err := conn.negotiateVersion(saslHandshake, v0, v1)
if err != nil {
return err
}

for completed := false; !completed; {
challenge, err := conn.saslAuthenticate(state)
switch err {
case nil:
case io.EOF:
// the broker may communicate a failed exchange by closing the
// connection (esp. in the case where we're passing opaque sasl
// data over the wire since there's no protocol info).
return SASLAuthenticationFailed
default:
return err
}

completed, state, err = sess.Next(ctx, challenge)
if err != nil {
return err
}
}

return nil
return conn.authenticateSASL(ctx, d.SASLMechanism, version)
}

func (d *Dialer) dialContext(ctx context.Context, network string, address string) (net.Conn, error) {
Expand Down
21 changes: 0 additions & 21 deletions go.sum

This file was deleted.

Loading