Skip to content

Commit dfa5a23

Browse files
authored
Merge pull request #4 from QXQZX/master
修复子协程和其他协程关闭问题 #3
2 parents e696552 + 0a12803 commit dfa5a23

File tree

2 files changed

+137
-98
lines changed

2 files changed

+137
-98
lines changed

client/client.go

+62-43
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
package main
22

33
import (
4+
"context"
45
"flag"
56
"fmt"
67
"io"
78
"net"
8-
"runtime"
99
"strings"
1010
"time"
1111
)
@@ -34,42 +34,50 @@ type server struct {
3434
}
3535

3636
// 从Server端读取数据
37-
func (s *server) Read() {
37+
func (s *server) Read(ctx context.Context) {
3838
// 如果10秒钟内没有消息传输,则Read函数会返回一个timeout的错误
3939
_ = s.conn.SetReadDeadline(time.Now().Add(time.Second * 10))
4040
for {
41-
data := make([]byte, 10240)
42-
n, err := s.conn.Read(data)
43-
if err != nil && err != io.EOF {
44-
// 读取超时,发送一个心跳包过去
45-
if strings.Contains(err.Error(), "timeout") {
46-
// 3秒发一次心跳
47-
_ = s.conn.SetReadDeadline(time.Now().Add(time.Second * 3))
48-
s.conn.Write([]byte("pi"))
49-
continue
41+
select {
42+
case <-ctx.Done():
43+
return
44+
default:
45+
data := make([]byte, 10240)
46+
n, err := s.conn.Read(data)
47+
if err != nil && err != io.EOF {
48+
// 读取超时,发送一个心跳包过去
49+
if strings.Contains(err.Error(), "timeout") {
50+
// 3秒发一次心跳
51+
_ = s.conn.SetReadDeadline(time.Now().Add(time.Second * 3))
52+
s.conn.Write([]byte("pi"))
53+
continue
54+
}
55+
fmt.Println("从server读取数据失败, ", err.Error())
56+
s.exit <- err
57+
return
5058
}
51-
fmt.Println("从server读取数据失败, ", err.Error())
52-
s.exit <- err
53-
runtime.Goexit()
54-
}
5559

56-
// 如果收到心跳包, 则跳过
57-
if data[0] == 'p' && data[1] == 'i' {
58-
fmt.Println("client收到心跳包")
59-
continue
60+
// 如果收到心跳包, 则跳过
61+
if data[0] == 'p' && data[1] == 'i' {
62+
fmt.Println("client收到心跳包")
63+
continue
64+
}
65+
s.read <- data[:n]
6066
}
61-
s.read <- data[:n]
6267
}
6368
}
6469

6570
// 将数据写入到Server端
66-
func (s *server) Write() {
71+
func (s *server) Write(ctx context.Context) {
6772
for {
6873
select {
74+
case <-ctx.Done():
75+
return
6976
case data := <-s.write:
7077
_, err := s.conn.Write(data)
7178
if err != nil && err != io.EOF {
7279
s.exit <- err
80+
return
7381
}
7482
}
7583
}
@@ -84,25 +92,34 @@ type local struct {
8492
exit chan error
8593
}
8694

87-
func (l *local) Read() {
95+
func (l *local) Read(ctx context.Context) {
8896

8997
for {
90-
data := make([]byte, 10240)
91-
n, err := l.conn.Read(data)
92-
if err != nil {
93-
l.exit <- err
98+
select {
99+
case <-ctx.Done():
100+
return
101+
default:
102+
data := make([]byte, 10240)
103+
n, err := l.conn.Read(data)
104+
if err != nil {
105+
l.exit <- err
106+
return
107+
}
108+
l.read <- data[:n]
94109
}
95-
l.read <- data[:n]
96110
}
97111
}
98112

99-
func (l *local) Write() {
113+
func (l *local) Write(ctx context.Context) {
100114
for {
101115
select {
116+
case <-ctx.Done():
117+
return
102118
case data := <-l.write:
103119
_, err := l.conn.Write(data)
104120
if err != nil {
105121
l.exit <- err
122+
return
106123
}
107124
}
108125
}
@@ -127,20 +144,19 @@ func main() {
127144
reConn: make(chan bool),
128145
}
129146

130-
go server.Read()
131-
go server.Write()
132-
133147
go handle(server)
134-
135148
<-server.reConn
136-
_ = server.conn.Close()
149+
//_ = server.conn.Close()
137150
}
138151

139152
}
140153

141154
func handle(server *server) {
142155
// 等待server端发来的信息,也就是说user来请求server了
143-
data := <-server.read
156+
ctx, cancel := context.WithCancel(context.Background())
157+
158+
go server.Read(ctx)
159+
go server.Write(ctx)
144160

145161
localConn, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", localPort))
146162
if err != nil {
@@ -154,10 +170,14 @@ func handle(server *server) {
154170
exit: make(chan error),
155171
}
156172

157-
go local.Read()
158-
go local.Write()
173+
go local.Read(ctx)
174+
go local.Write(ctx)
159175

160-
local.write <- data
176+
defer func() {
177+
_ = server.conn.Close()
178+
_ = local.conn.Close()
179+
server.reConn <- true
180+
}()
161181

162182
for {
163183
select {
@@ -169,13 +189,12 @@ func handle(server *server) {
169189

170190
case err := <-server.exit:
171191
fmt.Printf("server have err: %s", err.Error())
172-
_ = server.conn.Close()
173-
_ = local.conn.Close()
174-
server.reConn <- true
175-
192+
cancel()
193+
return
176194
case err := <-local.exit:
177-
fmt.Printf("server have err: %s", err.Error())
178-
_ = local.conn.Close()
195+
fmt.Printf("local have err: %s", err.Error())
196+
cancel()
197+
return
179198
}
180199
}
181200
}

server/server.go

+75-55
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
package main
22

33
import (
4+
"context"
45
"flag"
56
"fmt"
67
"io"
78
"net"
8-
"runtime"
99
"strings"
1010
"time"
1111
)
@@ -32,40 +32,49 @@ type client struct {
3232
}
3333

3434
// 从Client端读取数据
35-
func (c *client) Read() {
35+
func (c *client) Read(ctx context.Context) {
3636
// 如果10秒钟内没有消息传输,则Read函数会返回一个timeout的错误
3737
_ = c.conn.SetReadDeadline(time.Now().Add(time.Second * 10))
3838
for {
39-
data := make([]byte, 10240)
40-
n, err := c.conn.Read(data)
41-
if err != nil && err != io.EOF {
42-
if strings.Contains(err.Error(), "timeout") {
43-
// 设置读取时间为3秒,3秒后若读取不到, 则err会抛出timeout,然后发送心跳
44-
_ = c.conn.SetReadDeadline(time.Now().Add(time.Second * 3))
45-
c.conn.Write([]byte("pi"))
46-
continue
39+
select {
40+
case <-ctx.Done():
41+
return
42+
default:
43+
data := make([]byte, 10240)
44+
n, err := c.conn.Read(data)
45+
if err != nil && err != io.EOF {
46+
if strings.Contains(err.Error(), "timeout") {
47+
// 设置读取时间为3秒,3秒后若读取不到, 则err会抛出timeout,然后发送心跳
48+
_ = c.conn.SetReadDeadline(time.Now().Add(time.Second * 3))
49+
c.conn.Write([]byte("pi"))
50+
continue
51+
}
52+
fmt.Println("读取出现错误...")
53+
c.exit <- err
54+
return
4755
}
48-
fmt.Println("读取出现错误...")
49-
c.exit <- err
50-
}
5156

52-
// 收到心跳包,则跳过
53-
if data[0] == 'p' && data[1] == 'i' {
54-
fmt.Println("server收到心跳包")
55-
continue
57+
// 收到心跳包,则跳过
58+
if data[0] == 'p' && data[1] == 'i' {
59+
fmt.Println("server收到心跳包")
60+
continue
61+
}
62+
c.read <- data[:n]
5663
}
57-
c.read <- data[:n]
5864
}
5965
}
6066

6167
// 将数据写入到Client端
62-
func (c *client) Write() {
68+
func (c *client) Write(ctx context.Context) {
6369
for {
6470
select {
71+
case <-ctx.Done():
72+
return
6573
case data := <-c.write:
6674
_, err := c.conn.Write(data)
6775
if err != nil && err != io.EOF {
6876
c.exit <- err
77+
return
6978
}
7079
}
7180
}
@@ -81,26 +90,35 @@ type user struct {
8190
}
8291

8392
// 从User端读取数据
84-
func (u *user) Read() {
93+
func (u *user) Read(ctx context.Context) {
8594
_ = u.conn.SetReadDeadline(time.Now().Add(time.Second * 200))
8695
for {
87-
data := make([]byte, 10240)
88-
n, err := u.conn.Read(data)
89-
if err != nil && err != io.EOF {
90-
u.exit <- err
96+
select {
97+
case <-ctx.Done():
98+
return
99+
default:
100+
data := make([]byte, 10240)
101+
n, err := u.conn.Read(data)
102+
if err != nil && err != io.EOF {
103+
u.exit <- err
104+
return
105+
}
106+
u.read <- data[:n]
91107
}
92-
u.read <- data[:n]
93108
}
94109
}
95110

96111
// 将数据写给User端
97-
func (u *user) Write() {
112+
func (u *user) Write(ctx context.Context) {
98113
for {
99114
select {
115+
case <-ctx.Done():
116+
return
100117
case data := <-u.write:
101118
_, err := u.conn.Write(data)
102119
if err != nil && err != io.EOF {
103120
u.exit <- err
121+
return
104122
}
105123
}
106124
}
@@ -156,36 +174,47 @@ func main() {
156174
}
157175

158176
func HandleClient(client *client, userConnChan chan net.Conn) {
177+
ctx, cancel := context.WithCancel(context.Background())
178+
179+
go client.Read(ctx)
180+
go client.Write(ctx)
159181

160-
go client.Read()
161-
go client.Write()
182+
user := &user{
183+
read: make(chan []byte),
184+
write: make(chan []byte),
185+
exit: make(chan error),
186+
}
187+
188+
defer func() {
189+
_ = client.conn.Close()
190+
_ = user.conn.Close()
191+
client.reConn <- true
192+
}()
162193

163194
for {
164195
select {
165-
case err := <-client.exit:
166-
fmt.Printf("client出现错误, 开始重试, err: %s \n", err.Error())
167-
client.reConn <- true
168-
runtime.Goexit()
169-
170196
case userConn := <-userConnChan:
171-
user := &user{
172-
conn: userConn,
173-
read: make(chan []byte),
174-
write: make(chan []byte),
175-
exit: make(chan error),
176-
}
177-
go user.Read()
178-
go user.Write()
179-
180-
go handle(client, user)
197+
user.conn = userConn
198+
go handle(ctx, client, user)
199+
case err := <-client.exit:
200+
fmt.Println("client出现错误, 关闭连接", err.Error())
201+
cancel()
202+
return
203+
case err := <-user.exit:
204+
fmt.Println("user出现错误,关闭连接", err.Error())
205+
cancel()
206+
return
181207
}
182208
}
183209
}
184210

185211
// 将两个Socket通道链接
186212
// 1. 将从user收到的信息发给client
187213
// 2. 将从client收到信息发给user
188-
func handle(client *client, user *user) {
214+
func handle(ctx context.Context, client *client, user *user) {
215+
go user.Read(ctx)
216+
go user.Write(ctx)
217+
189218
for {
190219
select {
191220
case userRecv := <-user.read:
@@ -195,17 +224,8 @@ func handle(client *client, user *user) {
195224
// 收到从client发来的信息
196225
user.write <- clientRecv
197226

198-
case err := <-client.exit:
199-
fmt.Println("client出现错误, 关闭连接", err.Error())
200-
_ = client.conn.Close()
201-
_ = user.conn.Close()
202-
client.reConn <- true
203-
// 结束当前goroutine
204-
runtime.Goexit()
205-
206-
case err := <-user.exit:
207-
fmt.Println("user出现错误,关闭连接", err.Error())
208-
_ = user.conn.Close()
227+
case <-ctx.Done():
228+
return
209229
}
210230
}
211231
}

0 commit comments

Comments
 (0)