diff --git a/WelsonJS.Toolkit/WelsonJS.Launcher/WebSocketManager.cs b/WelsonJS.Toolkit/WelsonJS.Launcher/WebSocketManager.cs index c5f9a8f..1e7ef34 100644 --- a/WelsonJS.Toolkit/WelsonJS.Launcher/WebSocketManager.cs +++ b/WelsonJS.Toolkit/WelsonJS.Launcher/WebSocketManager.cs @@ -2,9 +2,11 @@ // SPDX-License-Identifier: GPL-3.0-or-later // SPDX-FileCopyrightText: 2025 Catswords OSS and WelsonJS Contributors // https://github.com/gnh1201/welsonjs -// +// using System; +using System.Buffers; using System.Collections.Concurrent; +using System.IO; using System.Net.WebSockets; using System.Security.Cryptography; using System.Text; @@ -21,6 +23,8 @@ namespace WelsonJS.Launcher public string Host; public int Port; public string Path; + // Ensures that send/receive is serialized per socket + public readonly SemaphoreSlim IoLock = new SemaphoreSlim(1, 1); } private readonly ConcurrentDictionary _pool = new ConcurrentDictionary(); @@ -36,33 +40,36 @@ namespace WelsonJS.Launcher } } - // Get an open WebSocket or connect a new one - public async Task GetOrCreateAsync(string host, int port, string path) + // Get an existing open WebSocket entry or create a new one + private async Task GetOrCreateAsync(string host, int port, string path) { string key = MakeKey(host, port, path); if (_pool.TryGetValue(key, out var entry)) { var sock = entry.Socket; + if (sock != null && sock.State == WebSocketState.Open) + return entry; - if (sock == null || sock.State != WebSocketState.Open) - { - Remove(host, port, path); - } - else - { - return sock; - } + Remove(host, port, path); } var newSock = new ClientWebSocket(); - var uri = new Uri($"ws://{host}:{port}/{path}"); + + // Build the WebSocket URI safely + var ub = new UriBuilder + { + Scheme = "ws", + Host = host, + Port = port, + Path = string.IsNullOrEmpty(path) ? "/" : (path.StartsWith("/") ? path : "/" + path) + }; try { - await newSock.ConnectAsync(uri, CancellationToken.None); + await newSock.ConnectAsync(ub.Uri, CancellationToken.None); - _pool[key] = new Entry + var newEntry = new Entry { Socket = newSock, Host = host, @@ -70,7 +77,8 @@ namespace WelsonJS.Launcher Path = path }; - return newSock; + _pool[key] = newEntry; + return newEntry; } catch (Exception ex) { @@ -92,22 +100,27 @@ namespace WelsonJS.Launcher entry.Socket?.Dispose(); } catch { /* Ignore dispose exceptions */ } + finally + { + try { entry.IoLock?.Dispose(); } catch { } + } } } - // Send and receive with automatic retry on first failure - public async Task SendAndReceiveAsync(string host, int port, string path, string message, int timeoutSec) + // Send a message and receive a response, with automatic retry on first failure + public async Task SendAndReceiveAsync(string host, int port, string path, string message, int timeoutSec, int maxMessageBytes = 8 * 1024 * 1024) { - byte[] buf = Encoding.UTF8.GetBytes(message); var cts = timeoutSec > 0 ? new CancellationTokenSource(TimeSpan.FromSeconds(timeoutSec)) : new CancellationTokenSource(); + byte[] buf = Encoding.UTF8.GetBytes(message); + for (int attempt = 0; attempt < 2; attempt++) { try { - return await TrySendAndReceiveAsync(host, port, path, buf, cts.Token); + return await TrySendAndReceiveAsync(host, port, path, buf, cts.Token, maxMessageBytes); } catch { @@ -119,22 +132,65 @@ namespace WelsonJS.Launcher throw new InvalidOperationException("Unreachable"); } - // Actual send and receive implementation - private async Task TrySendAndReceiveAsync(string host, int port, string path, byte[] buf, CancellationToken token) + // Actual send/receive logic with full-frame accumulation until EndOfMessage + private async Task TrySendAndReceiveAsync(string host, int port, string path, byte[] sendBuf, CancellationToken token, int maxMessageBytes) { try { - var sock = await GetOrCreateAsync(host, port, path); + var entry = await GetOrCreateAsync(host, port, path); + var sock = entry.Socket; if (sock.State != WebSocketState.Open) throw new WebSocketException("WebSocket is not in an open state"); - await sock.SendAsync(new ArraySegment(buf), WebSocketMessageType.Text, true, token); + await entry.IoLock.WaitAsync(token); + try + { + // Send message (single-frame; can be split if needed) + await sock.SendAsync(new ArraySegment(sendBuf), WebSocketMessageType.Text, true, token); - byte[] recv = new byte[4096]; - var result = await sock.ReceiveAsync(new ArraySegment(recv), token); + // Receive message until EndOfMessage is reached + var buffer = ArrayPool.Shared.Rent(8192); + try + { + using (var ms = new MemoryStream()) + { + while (true) + { + var res = await sock.ReceiveAsync(new ArraySegment(buffer), token); - return Encoding.UTF8.GetString(recv, 0, result.Count); + if (res.MessageType == WebSocketMessageType.Close) + { + // Server requested closure + try { await sock.CloseAsync(WebSocketCloseStatus.NormalClosure, "Closing as requested by server", token); } catch { } + throw new WebSocketException($"WebSocket closed by server: {sock.CloseStatus} {sock.CloseStatusDescription}"); + } + + if (res.Count > 0) + { + ms.Write(buffer, 0, res.Count); + + if (ms.Length > maxMessageBytes) + throw new InvalidOperationException($"Received message exceeds limit ({maxMessageBytes} bytes)."); + } + + if (res.EndOfMessage) + break; + } + + // Convert UTF-8 encoded text message to string + return Encoding.UTF8.GetString(ms.ToArray()); + } + } + finally + { + ArrayPool.Shared.Return(buffer); + } + } + finally + { + entry.IoLock.Release(); + } } catch (WebSocketException ex) {