@@ -3,19 +3,74 @@ package pq
3
3
import (
4
4
"crypto/tls"
5
5
"crypto/x509"
6
+ "fmt"
6
7
"io/ioutil"
7
8
"net"
8
9
"os"
9
10
"os/user"
10
11
"path/filepath"
11
12
"strings"
13
+ "sync"
12
14
)
13
15
16
+ // Registry for custom tls.Configs
17
+ var (
18
+ tlsConfigLock sync.RWMutex
19
+ tlsConfigRegistry map [string ]* tls.Config
20
+ )
21
+
22
+ func RegisterTLSConfig (key string , config * tls.Config ) error {
23
+ if _ , isBool := readBool (key ); isBool || strings .ToLower (key ) == "require" || strings .ToLower (key ) == "verify-ca" || strings .ToLower (key ) == "verify-full" || strings .ToLower (key ) == "disable" {
24
+ return fmt .Errorf ("key '%s' is reserved" , key )
25
+ }
26
+
27
+ tlsConfigLock .Lock ()
28
+ if tlsConfigRegistry == nil {
29
+ tlsConfigRegistry = make (map [string ]* tls.Config )
30
+ }
31
+
32
+ tlsConfigRegistry [key ] = config
33
+ tlsConfigLock .Unlock ()
34
+ return nil
35
+ }
36
+
37
+ // DeregisterTLSConfig removes the tls.Config associated with key.
38
+ func DeregisterTLSConfig (key string ) {
39
+ tlsConfigLock .Lock ()
40
+ if tlsConfigRegistry != nil {
41
+ delete (tlsConfigRegistry , key )
42
+ }
43
+ tlsConfigLock .Unlock ()
44
+ }
45
+
46
+ func getTLSConfigClone (key string ) (config * tls.Config ) {
47
+ tlsConfigLock .RLock ()
48
+ if v , ok := tlsConfigRegistry [key ]; ok {
49
+ config = v .Clone ()
50
+ }
51
+ tlsConfigLock .RUnlock ()
52
+ return
53
+ }
54
+
55
+ // Returns the bool value of the input.
56
+ // The 2nd return value indicates if the input was a valid bool value
57
+ func readBool (input string ) (value bool , valid bool ) {
58
+ switch input {
59
+ case "1" , "true" , "TRUE" , "True" :
60
+ return true , true
61
+ case "0" , "false" , "FALSE" , "False" :
62
+ return false , true
63
+ }
64
+
65
+ // Not a valid bool value
66
+ return
67
+ }
68
+
14
69
// ssl generates a function to upgrade a net.Conn based on the "sslmode" and
15
70
// related settings. The function is nil when no upgrade should take place.
16
71
func ssl (o values ) (func (net.Conn ) (net.Conn , error ), error ) {
17
72
verifyCaOnly := false
18
- tlsConf := tls.Config {}
73
+ tlsConf := & tls.Config {}
19
74
switch mode := o ["sslmode" ]; mode {
20
75
// "require" is the default.
21
76
case "" , "require" :
@@ -48,7 +103,12 @@ func ssl(o values) (func(net.Conn) (net.Conn, error), error) {
48
103
case "disable" :
49
104
return nil , nil
50
105
default :
51
- return nil , fmterrorf (`unsupported sslmode %q; only "require" (default), "verify-full", "verify-ca", and "disable" supported` , mode )
106
+ {
107
+ tlsConf = getTLSConfigClone (mode )
108
+ if tlsConf == nil {
109
+ return nil , fmterrorf (`unsupported sslmode %q; only "require" (default), "verify-full", "verify-ca", and "disable" supported` , mode )
110
+ }
111
+ }
52
112
}
53
113
54
114
// Set Server Name Indication (SNI), if enabled by connection parameters.
@@ -61,11 +121,7 @@ func ssl(o values) (func(net.Conn) (net.Conn, error), error) {
61
121
tlsConf .ServerName = o ["host" ]
62
122
}
63
123
64
- err := sslClientCertificates (& tlsConf , o )
65
- if err != nil {
66
- return nil , err
67
- }
68
- err = sslCertificateAuthority (& tlsConf , o )
124
+ err := sslClientCertificates (tlsConf , o )
69
125
if err != nil {
70
126
return nil , err
71
127
}
@@ -78,9 +134,9 @@ func ssl(o values) (func(net.Conn) (net.Conn, error), error) {
78
134
tlsConf .Renegotiation = tls .RenegotiateFreelyAsClient
79
135
80
136
return func (conn net.Conn ) (net.Conn , error ) {
81
- client := tls .Client (conn , & tlsConf )
137
+ client := tls .Client (conn , tlsConf )
82
138
if verifyCaOnly {
83
- err := sslVerifyCertificateAuthority (client , & tlsConf )
139
+ err := sslVerifyCertificateAuthority (client , tlsConf )
84
140
if err != nil {
85
141
return nil , err
86
142
}
0 commit comments