diff --git a/websocket/clientHandler.go b/websocket/clientHandler.go index 70f19ed..7379970 100644 --- a/websocket/clientHandler.go +++ b/websocket/clientHandler.go @@ -17,6 +17,12 @@ type ClientHandler struct { Clients models.Clients } +func SendBroadcast(msg []byte) { + for _, c := range models.Broadcast { + c.SendResponse(msg) + } +} + // initaiates new conections with client map func NewConnectionHandler() *ClientHandler { return &ClientHandler{ @@ -41,7 +47,6 @@ func (cH *ClientHandler) ConnectNewClient(id string, c *gin.Context) (client *mo cH.Lock() cH.Clients[id] = client cH.Unlock() - return client, nil } diff --git a/websocket/models/client.go b/websocket/models/client.go index eaf271c..031432f 100644 --- a/websocket/models/client.go +++ b/websocket/models/client.go @@ -1,10 +1,10 @@ package models import ( - "encoding/json" "fmt" "log" "net/http" + "slices" "time" "github.com/gin-gonic/gin" @@ -13,6 +13,8 @@ import ( var Origins []string = []string{"*"} +var Broadcast Clients = make(Clients) + const ( // Time allowed to write a message to the peer. writeWait = 10 * time.Second @@ -34,7 +36,6 @@ type Client struct { OnError func(err error) OnPing func() OnPong func() - sendPong chan string send chan []byte unregister chan []byte } @@ -43,17 +44,10 @@ var upgrader = websocket.Upgrader{ CheckOrigin: func(r *http.Request) bool { if len(Origins) == 0 { return false - } - if Origins[0] == "*" { + } else if Origins[0] == "*" { return true } - origin := r.Header.Get("Origin") - for _, o := range Origins { - if o == origin { - return true - } - } - return false + return slices.Contains(Origins, r.Header.Get("Origin")) }, EnableCompression: false, } @@ -68,22 +62,25 @@ func ConnectNewClient(id string, c *gin.Context) (*Client, error) { Id: id, Connected: true, conn: conn, - sendPong: make(chan string), - send: make(chan []byte), - unregister: make(chan []byte), + send: make(chan []byte, 512), + unregister: make(chan []byte, 256), } + Broadcast[client.Id] = client + conn.SetPingHandler(func(appData string) error { if client.OnPing != nil { client.OnPing() } conn.SetWriteDeadline(time.Now().Add(writeWait)) conn.SetReadDeadline(time.Now().Add(writeWait)) - client.sendPong <- appData + if err := client.conn.WriteControl(websocket.PongMessage, []byte(appData), time.Now().Add(pongWait)); err != nil { + client.OnError(err) + } return nil }) - conn.SetPongHandler(func(string) error { + conn.SetPongHandler(func(appData string) error { conn.SetReadDeadline(time.Now().Add(pongWait)) if client.OnPong != nil { client.OnPong() @@ -94,6 +91,7 @@ func ConnectNewClient(id string, c *gin.Context) (*Client, error) { // Start reading messages from client go client.Read() go client.Write() + go client.PingInterval(pingPeriod) return client, nil } @@ -105,6 +103,7 @@ func (c *Client) Read() { c.conn.SetReadDeadline(time.Now().Add(writeWait)) for c.Connected { msgType, msg, err := c.conn.ReadMessage() + if err != nil { c.handleError(fmt.Errorf("read error (id:%s): %w", c.Id, err)) return @@ -114,12 +113,10 @@ func (c *Client) Read() { c.Close(websocket.CloseNormalClosure, "Client closed") return case websocket.TextMessage: - if isPing := c.handleJsonPing(msg); !isPing { - if c.OnMessage != nil { - c.OnMessage(msg) - } else { - log.Printf("Received message but no handler set (id:%s): %s", c.Id, string(msg)) - } + if c.OnMessage != nil { + c.OnMessage(msg) + } else { + log.Printf("Received message but no handler set (id:%s): %s", c.Id, string(msg)) } default: log.Printf("Unhandled message type %d (id:%s)", msgType, c.Id) @@ -128,19 +125,15 @@ func (c *Client) Read() { } func (c *Client) Write() { - ticker := time.NewTicker(pingPeriod) - defer func() { - ticker.Stop() - c.conn.Close() - }() - for { + defer c.conn.Close() + for { select { case message, ok := <-c.send: c.conn.SetWriteDeadline(time.Now().Add(writeWait)) if !ok { // The hub closed the channel. - if err := c.conn.WriteMessage(websocket.CloseMessage, []byte{}); err != nil { + if err := c.conn.WriteMessage(websocket.CloseMessage, []byte("ping")); err != nil { c.handleError(err) return } @@ -152,26 +145,11 @@ func (c *Client) Write() { return } } - case <-ticker.C: - c.conn.SetWriteDeadline(time.Now().Add(writeWait)) - if err := c.conn.WriteMessage(websocket.PingMessage, nil); err != nil { - c.handleError(err) - return - } - - if c.OnPing != nil { - c.OnPing() - } - - case message, ok := <-c.sendPong: - if ok { - c.conn.WriteMessage(websocket.PongMessage, []byte(message)) - } case message := <-c.unregister: c.conn.WriteMessage(websocket.CloseMessage, message) c.Connected = false - close(c.sendPong) close(c.send) + delete(Broadcast, c.Id) close(c.unregister) return } @@ -179,31 +157,20 @@ func (c *Client) Write() { } } -func (c *Client) handleJsonPing(msg []byte) (isPing bool) { - var wsMsg WSMessage - err := json.Unmarshal(msg, &wsMsg) +func (c *Client) PingInterval(interval time.Duration) { + ticker := time.NewTicker(interval) + defer ticker.Stop() - if err == nil && wsMsg.IsPing() { - c.conn.SetReadDeadline(time.Now().Add(writeWait)) - // Respond with pong JSON - select { - case c.send <- GetPongByteSlice(): - default: - // optional: log or handle if send buffer is full - c.handleError(fmt.Errorf("failed to queue pong message")) - return - } - if err != nil { - c.handleError(fmt.Errorf("write pong error: %w", err)) - return - } + for range ticker.C { if c.OnPing != nil { c.OnPing() } - isPing = true - } - return + if err := c.conn.WriteControl(websocket.PingMessage, []byte("ping"), time.Now().Add(pongWait)); err != nil { + c.OnError(err) + return + } + } } func (c *Client) SendResponse(data []byte) { @@ -215,6 +182,7 @@ func (c *Client) SendResponse(data []byte) { func (c *Client) Close(code int, reason string) error { closeMsg := websocket.FormatCloseMessage(code, reason) + select { case c.unregister <- closeMsg: // Attempt to send default: // If the channel is full, this runs @@ -230,6 +198,7 @@ func (c *Client) handleError(err error) { if c.OnError != nil { c.OnError(err) } + if err := c.Close(websocket.CloseInternalServerErr, err.Error()); err != nil { if c.OnError != nil { c.OnError(err)