Skip to content

Commit c90ae64

Browse files
committed
add support custom tls config
1 parent d5affd5 commit c90ae64

File tree

2 files changed

+106
-9
lines changed

2 files changed

+106
-9
lines changed

doc.go

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,47 @@ Valid values for sslmode are:
6868
the server was signed by a trusted CA and the server host name
6969
matches the one in the certificate)
7070
71+
For support ssl key in memory, we extend sslmode. For example:
72+
73+
import (
74+
"crypto/tls"
75+
"crypto/x509"
76+
"io/ioutil"
77+
"log"
78+
79+
"github.com/lib/pq"
80+
)
81+
82+
func main() {
83+
rootCertPool := x509.NewCertPool()
84+
pem, err := ioutil.ReadFile("ca.crt")
85+
if err != nil {
86+
log.Fatal(err)
87+
}
88+
if ok := rootCertPool.AppendCertsFromPEM(pem); !ok {
89+
log.Fatal("Failed to append PEM.")
90+
}
91+
clientCert := make([]tls.Certificate, 0, 1)
92+
certs, err := tls.LoadX509KeyPair("client1.crt", "client1.key")
93+
if err != nil {
94+
log.Fatal(err)
95+
}
96+
clientCert = append(clientCert, certs)
97+
err = pq.RegisterTLSConfig("custom", &tls.Config{
98+
RootCAs: rootCertPool,
99+
Certificates: clientCert,
100+
ServerName: "pq.example.com",
101+
})
102+
if err != nil {
103+
log.Fatal(err)
104+
}
105+
connStr := "host=pq.example.com port=5432 user=user1 dbname=pqgotest password=pqgotest sslmode=custom"
106+
db, err := sql.Open("postgres", connStr)
107+
if err != nil {
108+
log.Fatal(err)
109+
}
110+
}
111+
71112
See http://www.postgresql.org/docs/current/static/libpq-connect.html#LIBPQ-CONNSTRING
72113
for more information about connection string parameters.
73114

ssl.go

Lines changed: 65 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,19 +3,74 @@ package pq
33
import (
44
"crypto/tls"
55
"crypto/x509"
6+
"fmt"
67
"io/ioutil"
78
"net"
89
"os"
910
"os/user"
1011
"path/filepath"
1112
"strings"
13+
"sync"
1214
)
1315

16+
// Registry for custom tls.Configs
17+
var (
18+
tlsConfigLock sync.RWMutex
19+
tlsConfigRegistry map[string]*tls.Config
20+
)
21+
22+
func RegisterTLSConfig(key string, config *tls.Config) error {
23+
if _, isBool := readBool(key); isBool || strings.ToLower(key) == "require" || strings.ToLower(key) == "verify-ca" || strings.ToLower(key) == "verify-full" || strings.ToLower(key) == "disable" {
24+
return fmt.Errorf("key '%s' is reserved", key)
25+
}
26+
27+
tlsConfigLock.Lock()
28+
if tlsConfigRegistry == nil {
29+
tlsConfigRegistry = make(map[string]*tls.Config)
30+
}
31+
32+
tlsConfigRegistry[key] = config
33+
tlsConfigLock.Unlock()
34+
return nil
35+
}
36+
37+
// DeregisterTLSConfig removes the tls.Config associated with key.
38+
func DeregisterTLSConfig(key string) {
39+
tlsConfigLock.Lock()
40+
if tlsConfigRegistry != nil {
41+
delete(tlsConfigRegistry, key)
42+
}
43+
tlsConfigLock.Unlock()
44+
}
45+
46+
func getTLSConfigClone(key string) (config *tls.Config) {
47+
tlsConfigLock.RLock()
48+
if v, ok := tlsConfigRegistry[key]; ok {
49+
config = v.Clone()
50+
}
51+
tlsConfigLock.RUnlock()
52+
return
53+
}
54+
55+
// Returns the bool value of the input.
56+
// The 2nd return value indicates if the input was a valid bool value
57+
func readBool(input string) (value bool, valid bool) {
58+
switch input {
59+
case "1", "true", "TRUE", "True":
60+
return true, true
61+
case "0", "false", "FALSE", "False":
62+
return false, true
63+
}
64+
65+
// Not a valid bool value
66+
return
67+
}
68+
1469
// ssl generates a function to upgrade a net.Conn based on the "sslmode" and
1570
// related settings. The function is nil when no upgrade should take place.
1671
func ssl(o values) (func(net.Conn) (net.Conn, error), error) {
1772
verifyCaOnly := false
18-
tlsConf := tls.Config{}
73+
tlsConf := &tls.Config{}
1974
switch mode := o["sslmode"]; mode {
2075
// "require" is the default.
2176
case "", "require":
@@ -48,7 +103,12 @@ func ssl(o values) (func(net.Conn) (net.Conn, error), error) {
48103
case "disable":
49104
return nil, nil
50105
default:
51-
return nil, fmterrorf(`unsupported sslmode %q; only "require" (default), "verify-full", "verify-ca", and "disable" supported`, mode)
106+
{
107+
tlsConf = getTLSConfigClone(mode)
108+
if tlsConf == nil {
109+
return nil, fmterrorf(`unsupported sslmode %q; only "require" (default), "verify-full", "verify-ca", and "disable" supported`, mode)
110+
}
111+
}
52112
}
53113

54114
// Set Server Name Indication (SNI), if enabled by connection parameters.
@@ -61,11 +121,7 @@ func ssl(o values) (func(net.Conn) (net.Conn, error), error) {
61121
tlsConf.ServerName = o["host"]
62122
}
63123

64-
err := sslClientCertificates(&tlsConf, o)
65-
if err != nil {
66-
return nil, err
67-
}
68-
err = sslCertificateAuthority(&tlsConf, o)
124+
err := sslClientCertificates(tlsConf, o)
69125
if err != nil {
70126
return nil, err
71127
}
@@ -78,9 +134,9 @@ func ssl(o values) (func(net.Conn) (net.Conn, error), error) {
78134
tlsConf.Renegotiation = tls.RenegotiateFreelyAsClient
79135

80136
return func(conn net.Conn) (net.Conn, error) {
81-
client := tls.Client(conn, &tlsConf)
137+
client := tls.Client(conn, tlsConf)
82138
if verifyCaOnly {
83-
err := sslVerifyCertificateAuthority(client, &tlsConf)
139+
err := sslVerifyCertificateAuthority(client, tlsConf)
84140
if err != nil {
85141
return nil, err
86142
}

0 commit comments

Comments
 (0)