Skip to content

Commit 6ecc178

Browse files
committed
Add keepalives support
1 parent 083382b commit 6ecc178

File tree

4 files changed

+115
-2
lines changed

4 files changed

+115
-2
lines changed

conn.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1068,6 +1068,8 @@ func isDriverSetting(key string) bool {
10681068
return true
10691069
case "fallback_application_name":
10701070
return true
1071+
case "keepalives", "keepalives_interval":
1072+
return true
10711073
case "connect_timeout":
10721074
return true
10731075
case "disable_prepared_binary_result":

connector.go

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,11 @@ import (
55
"database/sql/driver"
66
"errors"
77
"fmt"
8+
"net"
89
"os"
10+
"strconv"
911
"strings"
12+
"time"
1013
)
1114

1215
// Connector represents a fixed configuration for the pq driver with a given
@@ -107,9 +110,41 @@ func NewConnector(dsn string) (*Connector, error) {
107110
}
108111

109112
// SSL is not necessary or supported over UNIX domain sockets
110-
if network, _ := network(o); network == "unix" {
113+
ntw, _ := network(o)
114+
if ntw == "unix" {
111115
o["sslmode"] = "disable"
112116
}
113117

114-
return &Connector{opts: o, dialer: defaultDialer{}}, nil
118+
var d net.Dialer
119+
if ntw == "tcp" {
120+
d.KeepAlive, err = keepalive(o)
121+
if err != nil {
122+
return nil, err
123+
}
124+
}
125+
126+
return &Connector{opts: o, dialer: defaultDialer{d}}, nil
127+
}
128+
129+
// keepalive returns the interval between keep-alive probes controlled by keepalives_interval.
130+
// If zero, keep-alive probes are sent with a default value (see net.Dialer).
131+
// If negative, keep-alive probes are disabled.
132+
//
133+
// The keepalives parameter controls whether client-side TCP keepalives are used.
134+
// The default value is 1, meaning on, but you can change this to 0, meaning off, if keepalives are not wanted.
135+
func keepalive(o values) (time.Duration, error) {
136+
v, ok := o["keepalives"]
137+
if ok && v == "0" {
138+
return -1, nil
139+
}
140+
141+
if v, ok = o["keepalives_interval"]; !ok {
142+
return 0, nil
143+
}
144+
145+
keepintvl, err := strconv.ParseInt(v, 10, 0)
146+
if err != nil {
147+
return 0, fmt.Errorf("invalid value for parameter keepalives_interval: %w", err)
148+
}
149+
return time.Duration(keepintvl) * time.Second, nil
115150
}

connector_test.go

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,10 @@ import (
66
"context"
77
"database/sql"
88
"database/sql/driver"
9+
"errors"
10+
"strconv"
911
"testing"
12+
"time"
1013
)
1114

1215
func TestNewConnector_WorksWithOpenDB(t *testing.T) {
@@ -65,3 +68,70 @@ func TestNewConnector_Driver(t *testing.T) {
6568
}
6669
txn.Rollback()
6770
}
71+
72+
func TestNewConnectorKeepalive(t *testing.T) {
73+
c, err := NewConnector("keepalives=1 keepalives_interval=10")
74+
if err != nil {
75+
t.Fatal(err)
76+
}
77+
db := sql.OpenDB(c)
78+
defer db.Close()
79+
// database/sql might not call our Open at all unless we do something with
80+
// the connection
81+
txn, err := db.Begin()
82+
if err != nil {
83+
t.Fatal(err)
84+
}
85+
txn.Rollback()
86+
87+
d, _ := c.dialer.(defaultDialer)
88+
want := 10 * time.Second
89+
if want != d.d.KeepAlive {
90+
t.Fatalf("expected: %v, got: %v", want, d.d.KeepAlive)
91+
}
92+
}
93+
94+
func TestKeepalive(t *testing.T) {
95+
var tt = map[string]struct {
96+
input values
97+
want time.Duration
98+
}{
99+
"keepalives on": {values{"keepalives": "1"}, 0},
100+
"keepalives on by default": {nil, 0},
101+
"keepalives off": {values{"keepalives": "0"}, -1},
102+
"keepalives_interval 5 seconds": {values{"keepalives_interval": "5"}, 5 * time.Second},
103+
"keepalives_interval default": {values{"keepalives_interval": "0"}, 0},
104+
"keepalives_interval off": {values{"keepalives_interval": "-1"}, -1 * time.Second},
105+
}
106+
107+
for name, tc := range tt {
108+
t.Run(name, func(t *testing.T) {
109+
got, err := keepalive(tc.input)
110+
if err != nil {
111+
t.Fatal(err)
112+
}
113+
if tc.want != got {
114+
t.Fatalf("expected: %v, got: %v", tc.want, got)
115+
}
116+
})
117+
}
118+
}
119+
120+
func TestKeepaliveError(t *testing.T) {
121+
var tt = map[string]struct {
122+
input values
123+
want error
124+
}{
125+
"keepalives_interval whitespace": {values{"keepalives_interval": " "}, strconv.ErrSyntax},
126+
"keepalives_interval float": {values{"keepalives_interval": "1.1"}, strconv.ErrSyntax},
127+
}
128+
129+
for name, tc := range tt {
130+
t.Run(name, func(t *testing.T) {
131+
_, err := keepalive(tc.input)
132+
if !errors.Is(err, tc.want) {
133+
t.Fatalf("expected: %v, got: %v", tc.want, err)
134+
}
135+
})
136+
}
137+
}

doc.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,12 @@ supported:
5151
* sslmode - Whether or not to use SSL (default is require, this is not
5252
the default for libpq)
5353
* fallback_application_name - An application_name to fall back to if one isn't provided.
54+
* keepalives - Whether or not to use client-side TCP keepalives
55+
(the default value is 1, meaning on, but you can change this to 0, meaning off)
56+
* keepalives_interval - The number of seconds after which a TCP keepalive message
57+
that is not acknowledged by the server should be retransmitted.
58+
If zero or not specified, keep-alive probes are sent with a default value (see net.Dialer).
59+
If negative, keep-alive probes are disabled.
5460
* connect_timeout - Maximum wait for connection, in seconds. Zero or
5561
not specified means wait indefinitely.
5662
* sslcert - Cert file location. The file must contain PEM encoded data.

0 commit comments

Comments
 (0)