diff --git a/websocket/models/client.go b/websocket/models/client.go index 8e9659b..eaf271c 100644 --- a/websocket/models/client.go +++ b/websocket/models/client.go @@ -27,13 +27,14 @@ const ( type Client struct { Id string Connected bool `json:"connected"` - Conn *websocket.Conn `json:"-"` + conn *websocket.Conn `json:"-"` OnOpen func() OnMessage func(data []byte) OnClose func(code int, reason string) OnError func(err error) OnPing func() OnPong func() + sendPong chan string send chan []byte unregister chan []byte } @@ -66,7 +67,8 @@ func ConnectNewClient(id string, c *gin.Context) (*Client, error) { client := &Client{ Id: id, Connected: true, - Conn: conn, + conn: conn, + sendPong: make(chan string), send: make(chan []byte), unregister: make(chan []byte), } @@ -77,7 +79,8 @@ func ConnectNewClient(id string, c *gin.Context) (*Client, error) { } conn.SetWriteDeadline(time.Now().Add(writeWait)) conn.SetReadDeadline(time.Now().Add(writeWait)) - return conn.WriteMessage(websocket.PongMessage, []byte(appData)) + client.sendPong <- appData + return nil }) conn.SetPongHandler(func(string) error { @@ -99,9 +102,9 @@ func (c *Client) Read() { c.OnOpen() } - c.Conn.SetReadDeadline(time.Now().Add(writeWait)) + c.conn.SetReadDeadline(time.Now().Add(writeWait)) for c.Connected { - msgType, msg, err := c.Conn.ReadMessage() + msgType, msg, err := c.conn.ReadMessage() if err != nil { c.handleError(fmt.Errorf("read error (id:%s): %w", c.Id, err)) return @@ -128,30 +131,30 @@ func (c *Client) Write() { ticker := time.NewTicker(pingPeriod) defer func() { ticker.Stop() - c.Conn.Close() + c.conn.Close() }() for { select { case message, ok := <-c.send: - c.Conn.SetWriteDeadline(time.Now().Add(writeWait)) + c.conn.SetWriteDeadline(time.Now().Add(writeWait)) if !ok { // The hub closed the channel. - if err := c.Conn.WriteMessage(websocket.CloseMessage, []byte{}); err != nil { + if err := c.conn.WriteMessage(websocket.CloseMessage, []byte{}); err != nil { c.handleError(err) return } c.handleError(fmt.Errorf("server %s closed channel", c.Id)) return } else { - if err := c.Conn.WriteMessage(websocket.TextMessage, message); err != nil { + if err := c.conn.WriteMessage(websocket.TextMessage, message); err != nil { c.handleError(err) return } } case <-ticker.C: - c.Conn.SetWriteDeadline(time.Now().Add(writeWait)) - if err := c.Conn.WriteMessage(websocket.PingMessage, nil); err != nil { + c.conn.SetWriteDeadline(time.Now().Add(writeWait)) + if err := c.conn.WriteMessage(websocket.PingMessage, nil); err != nil { c.handleError(err) return } @@ -159,9 +162,15 @@ func (c *Client) Write() { if c.OnPing != nil { c.OnPing() } + + case message, ok := <-c.sendPong: + if ok { + c.conn.WriteMessage(websocket.PongMessage, []byte(message)) + } case message := <-c.unregister: - c.Conn.WriteMessage(websocket.CloseMessage, message) + c.conn.WriteMessage(websocket.CloseMessage, message) c.Connected = false + close(c.sendPong) close(c.send) close(c.unregister) return @@ -175,7 +184,7 @@ func (c *Client) handleJsonPing(msg []byte) (isPing bool) { err := json.Unmarshal(msg, &wsMsg) if err == nil && wsMsg.IsPing() { - c.Conn.SetReadDeadline(time.Now().Add(writeWait)) + c.conn.SetReadDeadline(time.Now().Add(writeWait)) // Respond with pong JSON select { case c.send <- GetPongByteSlice():