From a413c56a6f444fb8d96c645c2828f669dcdfcf7d Mon Sep 17 00:00:00 2001 From: stary <834207172@qq.com> Date: Sat, 23 May 2026 23:49:47 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BC=98=E5=8C=96=E4=BB=A3=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/class/websockethandler.ts | 215 ++++++++++++++++++++++------------ src/server.ts | 66 ++++++----- src/websocket.ts | 29 ++++- test/websockethandler.test.ts | 84 ++++++++----- 4 files changed, 251 insertions(+), 143 deletions(-) diff --git a/src/class/websockethandler.ts b/src/class/websockethandler.ts index f4d307a..7f6cba0 100644 --- a/src/class/websockethandler.ts +++ b/src/class/websockethandler.ts @@ -35,6 +35,14 @@ interface UserInfo { avatar?: string; } +interface AppWebSocket extends WebSocket { + heartbeatTimer?: ReturnType; + lastActivity?: number; + participantId?: string; + socketId?: string; + userInfo?: UserInfo; +} + interface OnlineUser { socketId: string; connectionId: string; @@ -52,6 +60,57 @@ interface OnlineUser { */ const connectionGroup: Map = new Map(); +function asAppWebSocket(ws: WebSocket): AppWebSocket { + return ws as AppWebSocket; +} + +function ensureSocketId(ws: WebSocket): string { + const socket = asAppWebSocket(ws); + socket.socketId = socket.socketId || `ws_${Date.now()}_${Math.random().toString(36).slice(2, 8)}`; + return socket.socketId; +} + +function ensureParticipantId(ws: WebSocket): string { + const socket = asAppWebSocket(ws); + socket.participantId = socket.participantId || `p_${Date.now()}_${Math.random().toString(36).slice(2, 8)}`; + return socket.participantId; +} + +function getParticipantId(ws: WebSocket): string { + return asAppWebSocket(ws).participantId || ''; +} + +function getSocketId(ws: WebSocket): string { + return asAppWebSocket(ws).socketId || ''; +} + +function getUserInfo(ws: WebSocket): UserInfo { + return asAppWebSocket(ws).userInfo || { id: '', name: '', avatar: '' }; +} + +function setUserInfo(ws: WebSocket, userInfo: UserInfo): void { + asAppWebSocket(ws).userInfo = userInfo; +} + +function safeSend(ws: WebSocket, payload: unknown): boolean { + if (ws.readyState !== ws.OPEN) { + log(LogLevel.warn, 'Skip send on closed WebSocket'); + return false; + } + + ws.send(JSON.stringify(payload)); + return true; +} + +function findParticipantSocket(group: ConnectionGroup, participantId: string): WebSocket | null { + for (const participantWs of Array.from(group.participants)) { + if (getParticipantId(participantWs) === participantId) { + return participantWs; + } + } + return null; +} + /** * 获取或创建WebSocket会话的连接ID集合 * @param session WebSocket会话实例 @@ -79,6 +138,8 @@ function getOrCreateConnectionIds(session: WebSocket): Set { function reset(mode: string): void { // 设置是否为私有模式 isPrivate = mode == "private"; + clients.clear(); + connectionGroup.clear(); } /** @@ -89,9 +150,9 @@ function add(ws: WebSocket): void { // 为新连接创建空的连接ID集合 const id = new Set(); clients.set(ws, id); - (ws as any).socketId = (ws as any).socketId || `ws_${Date.now()}_${Math.random().toString(36).slice(2, 8)}`; + const socketId = ensureSocketId(ws); // 记录添加WebSocket连接的日志 - log(LogLevel.log, `Add WebSocket: ${(ws as any).socketId.toString() }`); + log(LogLevel.log, `Add WebSocket: ${socketId}`); } /** @@ -117,11 +178,11 @@ function broadcastToGroup(connectionId: string, senderWs: WebSocket, message: an // 如果发送者是host,转发给所有participants if (senderWs === group.host) { group.participants.forEach(participantWs => { - participantWs.send(JSON.stringify(message)); + safeSend(participantWs, message); }); } else { // 如果发送者是participant,转发给host - group.host.send(JSON.stringify(message)); + safeSend(group.host, message); } } @@ -138,13 +199,13 @@ function remove(ws: WebSocket): void { if (group) { if (group.host === ws) { group.participants.forEach(participantWs => { - participantWs.send(JSON.stringify({ type: "disconnect", connectionId: connectionId, reason: "host-left" })); + safeSend(participantWs, { type: "disconnect", connectionId: connectionId, reason: "host-left" }); }); connectionGroup.delete(connectionId); } else { group.participants.delete(ws); // 包含participantId,让host能识别是哪个participant离开 - group.host.send(JSON.stringify({ type: "participant-left", connectionId: connectionId, participantId: (ws as any).participantId })); + safeSend(group.host, { type: "participant-left", connectionId: connectionId, participantId: getParticipantId(ws) }); } } log(LogLevel.log, `Remove connectionId: ${connectionId}`); @@ -162,7 +223,7 @@ function remove(ws: WebSocket): void { function onConnect(ws: WebSocket, connectionId: string): void { let polite = true; // 为每个WebSocket生成唯一的participantId - const participantId = (ws as any).participantId = (ws as any).participantId || `p_${Date.now()}_${Math.random().toString(36).slice(2, 8)}`; + const participantId = ensureParticipantId(ws); if (isPrivate) { if (connectionGroup.has(connectionId)) { @@ -170,7 +231,7 @@ function onConnect(ws: WebSocket, connectionId: string): void { group.participants.add(ws); log(LogLevel.log, `Participant ${participantId} joined connectionId: ${connectionId}, total participants: ${group.participants.size}`); // 通知host有新participant加入 - group.host.send(JSON.stringify({ type: "participant-joined", connectionId: connectionId, participantId: participantId })); + safeSend(group.host, { type: "participant-joined", connectionId: connectionId, participantId: participantId }); } else { connectionGroup.set(connectionId, { host: ws, participants: new Set() }); polite = false; @@ -181,7 +242,7 @@ function onConnect(ws: WebSocket, connectionId: string): void { const connectionIds = getOrCreateConnectionIds(ws); connectionIds.add(connectionId); const role = polite ? 'participant' : 'host'; - ws.send(JSON.stringify({ type: "connect", connectionId: connectionId, polite: polite, role: role, participantId: participantId })); + safeSend(ws, { type: "connect", connectionId: connectionId, polite: polite, role: role, participantId: participantId }); } /** @@ -203,20 +264,20 @@ function onDisconnect(ws: WebSocket, connectionId: string): void { if (group.host === ws) { // host断开连接,通知所有participants房间已关闭,并删除连接组 group.participants.forEach(participantWs => { - participantWs.send(JSON.stringify({ type: "disconnect", connectionId: connectionId, reason: "host-left" })); + safeSend(participantWs, { type: "disconnect", connectionId: connectionId, reason: "host-left" }); }); connectionGroup.delete(connectionId); log(LogLevel.log, `Host disconnected, room ${connectionId} deleted, notified ${group.participants.size} participants`); } else { // participant断开连接,从组中移除并通知host(使用participant-left类型,host不会关闭房间) group.participants.delete(ws); - group.host.send(JSON.stringify({ type: "participant-left", connectionId: connectionId, participantId: (ws as any).participantId })); + safeSend(group.host, { type: "participant-left", connectionId: connectionId, participantId: getParticipantId(ws) }); log(LogLevel.log, `Participant left connectionId: ${connectionId}, remaining participants: ${group.participants.size}`); } } // 向当前连接发送断开连接消息 - ws.send(JSON.stringify({ type: "disconnect", connectionId: connectionId })); + safeSend(ws, { type: "disconnect", connectionId: connectionId }); //RemoveHeartbeat(ws); // 记录断开连接的日志 log(LogLevel.log, `Disconnect connectionId: ${connectionId}`); @@ -235,30 +296,29 @@ function onOffer(ws: WebSocket, message: any): void { if (isPrivate) { if (connectionGroup.has(connectionId)) { const group = connectionGroup.get(connectionId); - const senderParticipantId = (ws as any).participantId; + const senderParticipantId = getParticipantId(ws); const targetParticipantId = message.participantId; if (group.host === ws) { // host发送offer给特定participant(多peer模式下按participantId路由) newOffer.polite = true; if (targetParticipantId) { // 路由到指定participant - group.participants.forEach(participantWs => { - if ((participantWs as any).participantId === targetParticipantId) { - participantWs.send(JSON.stringify({ from: connectionId, to: "", type: "offer", data: newOffer, participantId: targetParticipantId })); - } - }); + const participantWs = findParticipantSocket(group, targetParticipantId); + if (participantWs) { + safeSend(participantWs, { from: connectionId, to: "", type: "offer", data: newOffer, participantId: targetParticipantId }); + } } else { // 兼容:无目标时广播给所有participants group.participants.forEach(participantWs => { - const pid = (participantWs as any).participantId; - participantWs.send(JSON.stringify({ from: connectionId, to: "", type: "offer", data: newOffer, participantId: pid })); + const pid = getParticipantId(participantWs); + safeSend(participantWs, { from: connectionId, to: "", type: "offer", data: newOffer, participantId: pid }); }); } } else { // participant发送offer给host,携带该participant的participantId // host端应为impolite(polite=false),确保perfect negotiation中host优先 newOffer.polite = false; - group.host.send(JSON.stringify({ from: connectionId, to: "", type: "offer", data: newOffer, participantId: senderParticipantId })); + safeSend(group.host, { from: connectionId, to: "", type: "offer", data: newOffer, participantId: senderParticipantId }); } } return; @@ -273,7 +333,7 @@ function onOffer(ws: WebSocket, message: any): void { if (k == ws) { return; } - k.send(JSON.stringify({ from: connectionId, to: "", type: "offer", data: newOffer })); + safeSend(k, { from: connectionId, to: "", type: "offer", data: newOffer }); }); } @@ -294,27 +354,26 @@ function onAnswer(ws: WebSocket, message: any): void { } const group = connectionGroup.get(connectionId); - const senderParticipantId = (ws as any).participantId; + const senderParticipantId = getParticipantId(ws); // 从answer消息中获取目标participantId(host回复时指定) const targetParticipantId = message.participantId; if (group.host === ws) { // host发送answer给特定participant if (targetParticipantId) { - group.participants.forEach(participantWs => { - if ((participantWs as any).participantId === targetParticipantId) { - participantWs.send(JSON.stringify({ from: connectionId, to: "", type: "answer", data: newAnswer, participantId: targetParticipantId })); - } - }); + const participantWs = findParticipantSocket(group, targetParticipantId); + if (participantWs) { + safeSend(participantWs, { from: connectionId, to: "", type: "answer", data: newAnswer, participantId: targetParticipantId }); + } } else { // 兼容:没有targetParticipantId时广播给所有participants group.participants.forEach(participantWs => { - participantWs.send(JSON.stringify({ from: connectionId, to: "", type: "answer", data: newAnswer })); + safeSend(participantWs, { from: connectionId, to: "", type: "answer", data: newAnswer }); }); } } else { // participant发送answer给host,携带自己的participantId - group.host.send(JSON.stringify({ from: connectionId, to: "", type: "answer", data: newAnswer, participantId: senderParticipantId })); + safeSend(group.host, { from: connectionId, to: "", type: "answer", data: newAnswer, participantId: senderParticipantId }); } } @@ -327,7 +386,7 @@ function onAnswer(ws: WebSocket, message: any): void { function onCandidate(ws: WebSocket, message: any): void { const connectionId = message.connectionId; const candidate = new Candidate(message.candidate, message.sdpMLineIndex, message.sdpMid, Date.now()); - const senderParticipantId = (ws as any).participantId; + const senderParticipantId = getParticipantId(ws); const targetParticipantId = message.participantId; if (isPrivate) { @@ -336,19 +395,18 @@ function onCandidate(ws: WebSocket, message: any): void { if (group.host === ws) { // host发送candidate给特定participant if (targetParticipantId) { - group.participants.forEach(participantWs => { - if ((participantWs as any).participantId === targetParticipantId) { - participantWs.send(JSON.stringify({ from: connectionId, to: "", type: "candidate", data: candidate, participantId: targetParticipantId })); - } - }); + const participantWs = findParticipantSocket(group, targetParticipantId); + if (participantWs) { + safeSend(participantWs, { from: connectionId, to: "", type: "candidate", data: candidate, participantId: targetParticipantId }); + } } else { group.participants.forEach(participantWs => { - participantWs.send(JSON.stringify({ from: connectionId, to: "", type: "candidate", data: candidate })); + safeSend(participantWs, { from: connectionId, to: "", type: "candidate", data: candidate }); }); } } else { // participant发送candidate给host,携带自己的participantId - group.host.send(JSON.stringify({ from: connectionId, to: "", type: "candidate", data: candidate, participantId: senderParticipantId })); + safeSend(group.host, { from: connectionId, to: "", type: "candidate", data: candidate, participantId: senderParticipantId }); } } return; @@ -364,16 +422,16 @@ function onCallConnectionId(ws: WebSocket, message: any): void { const group = connectionGroup.get(connectionId); if (group.host !== ws) { // participant发起呼叫,通知host - group.host.send(JSON.stringify({ from: connectionId, to: "", type: "call-request", data: connectionId })); + safeSend(group.host, { from: connectionId, to: "", type: "call-request", data: connectionId }); } } else { // 兼容旧的广播方式 - clients.forEach((_v, k) => { + clients.forEach((connectionIds, k) => { if (k === ws) { return; } - if (_v == clientId) { - k.send(JSON.stringify({ from: connectionId, to: "", type: "call-request", data: connectionId })); + if (connectionIds.has(clientId)) { + safeSend(k, { from: connectionId, to: "", type: "call-request", data: connectionId }); } }); } @@ -384,13 +442,13 @@ function onCallConnectionId(ws: WebSocket, message: any): void { * @param message 消息数据 */ function onHostUserInfo(ws: WebSocket, message: any): void { - (ws as any).userInfo = { + setUserInfo(ws, { id: message.id || '', name: message.name || '匿名用户', avatar: message.avatar || '' - }; + }); - log(LogLevel.log, 'Updated current ws userInfo:', (ws as any).userInfo); + log(LogLevel.log, 'Updated current ws userInfo:', getUserInfo(ws)); } function onInviteCall(ws: WebSocket, message: any): void { const connectionId = message.connectionId as string; @@ -403,38 +461,39 @@ function onInviteCall(ws: WebSocket, message: any): void { return; } - const userInfo = ((clientWs as any).userInfo || {}) as UserInfo; - if ((targetSocketId && (clientWs as any).socketId === targetSocketId) || + const userInfo = getUserInfo(clientWs); + if ((targetSocketId && getSocketId(clientWs) === targetSocketId) || (targetUserId && userInfo.id === targetUserId)) { targetWs = clientWs; } }); if (!targetWs) { - ws.send(JSON.stringify({ + safeSend(ws, { type: 'invite-failed', data: { connectionId, reason: 'target-offline' } - })); + }); log(LogLevel.warn, `invite-call target not found: socketId=${targetSocketId}, userId=${targetUserId}`); return; } - targetWs.send(JSON.stringify({ + const inviterInfo = getUserInfo(ws); + safeSend(targetWs, { type: 'invite-call', data: { connectionId, - inviterSocketId: (ws as any).socketId || '', - inviterUserId: message.inviterUserId || (((ws as any).userInfo || {}) as UserInfo).id || '', - inviterName: message.inviterName || (((ws as any).userInfo || {}) as UserInfo).name || '邀请方', - inviterAvatar: message.inviterAvatar || (((ws as any).userInfo || {}) as UserInfo).avatar || '', + inviterSocketId: getSocketId(ws), + inviterUserId: message.inviterUserId || inviterInfo.id || '', + inviterName: message.inviterName || getUserInfo(ws).name || '邀请方', + inviterAvatar: message.inviterAvatar || inviterInfo.avatar || '', applyReason: message.applyReason || message.reason || '', targetSocketId: targetSocketId || '', targetUserId: targetUserId || '' } - })); + }); log(LogLevel.log, `Forwarded invite-call to socketId=${targetSocketId}, userId=${targetUserId}, connectionId=${connectionId}`); } @@ -452,56 +511,56 @@ function onBroadcast(ws: WebSocket, message: any): void { if (connectionGroup.has(targetConnectionId)) { const group = connectionGroup.get(targetConnectionId); // 向组内所有成员发送消息 - group.host.send(JSON.stringify({ + safeSend(group.host, { type: "broadcast", message: broadcastMessage, from: "server" - })); + }); group.participants.forEach(participantWs => { - participantWs.send(JSON.stringify({ + safeSend(participantWs, { type: "broadcast", message: broadcastMessage, from: "server" - })); + }); }); } } else { // 全局广播:向所有客户端发送消息 clients.forEach((_v, k) => { - k.send(JSON.stringify({ + safeSend(k, { type: "broadcast", message: broadcastMessage, from: "server" - })); + }); }); } } function AddHeartbeat(ws: WebSocket, connectionId: string) { // 初始化心跳检测 - (ws as any).lastActivity = Date.now(); + asAppWebSocket(ws).lastActivity = Date.now(); // 设置心跳检测定时器,每30秒发送一次ping - (ws as any).heartbeatTimer = setInterval(() => { + asAppWebSocket(ws).heartbeatTimer = setInterval(() => { const now = Date.now(); // 检查上次活动时间,如果超过60秒没有活动,关闭连接 - if (now - (ws as any).lastActivity > 10000) { + if (now - (asAppWebSocket(ws).lastActivity || 0) > 10000) { log(LogLevel.warn, 'WebSocket connection timeout, closing...'); - clearInterval((ws as any).heartbeatTimer); + clearInterval(asAppWebSocket(ws).heartbeatTimer); //ws.close(); onDisconnect(ws, connectionId); } else { // 发送ping消息 - ws.send(JSON.stringify({ from: connectionId, to: "", type: "on-message", data: { type: "ping" }.toString() })); - log(LogLevel.log, 'WebSocket connection heartbeat, lastActivity: ', (ws as any).lastActivity); + safeSend(ws, { from: connectionId, to: "", type: "on-message", data: JSON.stringify({ type: "ping" }) }); + log(LogLevel.log, 'WebSocket connection heartbeat, lastActivity: ', asAppWebSocket(ws).lastActivity); } }, 3000); } function RemoveHeartbeat(ws: WebSocket) { // 清除心跳检测定时器 - if ((ws as any).heartbeatTimer) { - clearInterval((ws as any).heartbeatTimer); + if (asAppWebSocket(ws).heartbeatTimer) { + clearInterval(asAppWebSocket(ws).heartbeatTimer); } } @@ -541,11 +600,11 @@ function getSocketRole(ws: WebSocket, connectionIds: string[]): 'host' | 'partic */ function toOnlineUser(ws: WebSocket): OnlineUser { const connectionIds = Array.from(clients.get(ws) || []); - const userInfo = ((ws as any).userInfo || {}) as UserInfo; + const userInfo = getUserInfo(ws); return { - socketId: (ws as any).socketId || '', + socketId: getSocketId(ws), connectionId: connectionIds[0] || '', - participantId: (ws as any).participantId || '', + participantId: getParticipantId(ws), role: getSocketRole(ws, connectionIds), userId: userInfo.id || '', name: userInfo.name || '', @@ -580,17 +639,17 @@ function onMessage(ws: WebSocket, message: any): void { // 获取连接ID const connectionId = message.connectionId; const chatMessage = message.message; - const senderParticipantId = (ws as any).participantId; + const senderParticipantId = getParticipantId(ws); if (!connectionId || !chatMessage || typeof chatMessage !== 'object') { log(LogLevel.warn, 'Ignored malformed on-message payload:', message); return; } if (chatMessage && chatMessage.type === 'user-info' && chatMessage.data) { - (ws as any).userInfo = { + setUserInfo(ws, { id: chatMessage.data.id || '', name: chatMessage.data.name || '匿名用户', avatar: chatMessage.data.avatar || '' - }; + }); } chatMessage.participantId = senderParticipantId; chatMessage.connectionId = connectionId; @@ -599,15 +658,15 @@ function onMessage(ws: WebSocket, message: any): void { if (group.host === ws) { // host发送消息,转发给所有participants group.participants.forEach(participantWs => { - participantWs.send(JSON.stringify({ from: connectionId, to: "", type: "on-message", data: JSON.stringify(chatMessage) })); + safeSend(participantWs, { from: connectionId, to: "", type: "on-message", data: JSON.stringify(chatMessage) }); }); } else { // participant发送消息,转发给host(附带participantId)和其他participants - group.host.send(JSON.stringify({ from: connectionId, to: "", type: "on-message", data: JSON.stringify(chatMessage), participantId: senderParticipantId })); + safeSend(group.host, { from: connectionId, to: "", type: "on-message", data: JSON.stringify(chatMessage), participantId: senderParticipantId }); // 同时转发给其他participants(排除发送者自身) group.participants.forEach(participantWs => { if (participantWs !== ws) { - participantWs.send(JSON.stringify({ from: connectionId, to: "", type: "on-message", data: JSON.stringify(chatMessage), participantId: senderParticipantId })); + safeSend(participantWs, { from: connectionId, to: "", type: "on-message", data: JSON.stringify(chatMessage), participantId: senderParticipantId }); } }); } diff --git a/src/server.ts b/src/server.ts index b7ec123..6159469 100644 --- a/src/server.ts +++ b/src/server.ts @@ -6,7 +6,7 @@ import { v4 as uuid } from 'uuid'; import signaling from './signaling'; import { log, LogLevel } from './log'; import Options from './class/options'; -import { reset as resetHandler }from './class/httphandler'; +import { reset as resetHandler } from './class/httphandler'; import { initSwagger } from './swagger'; const cors = require('cors'); @@ -21,6 +21,7 @@ function safeAvatarExtension(file: any): string { if (ALLOWED_AVATAR_EXTENSIONS.has(originalExt)) { return originalExt; } + switch (file.mimetype) { case 'image/jpeg': return '.jpg'; @@ -36,63 +37,63 @@ function safeAvatarExtension(file: any): string { } function isAllowedAvatar(file: any): boolean { - const ext = path.extname(file.originalname || '').toLowerCase(); - return ALLOWED_AVATAR_MIME_TYPES.has(file.mimetype) && ALLOWED_AVATAR_EXTENSIONS.has(ext); + return ALLOWED_AVATAR_MIME_TYPES.has(file.mimetype) && safeAvatarExtension(file).length > 0; } export const createServer = (config: Options): express.Express => { const app: express.Express = express(); resetHandler(config.mode); - // logging http access - if (config.logging != "none") { + + if (config.logging !== 'none') { app.use(morgan(config.logging)); } - // const signal = require('./signaling'); - app.use(cors({origin: '*'})); + + app.use(cors({ origin: '*' })); app.use(express.urlencoded({ extended: true })); app.use(express.json()); - app.get('/config', (req, res) => res.json({ - useWebSocket: config.type == 'websocket', + + app.get('/config', (_req, res) => res.json({ + useWebSocket: config.type === 'websocket', startupMode: config.mode, logging: config.logging, protocol: config.secure ? 'https' : 'http', port: config.port })); + app.use('/signaling', signaling); app.use(express.static(path.join(__dirname, '../client/public'))); app.use('/module', express.static(path.join(__dirname, '../client/src'))); - app.get('/', (req, res) => { - const indexPagePath: string = path.join(__dirname, '../client/public/index.html'); + + app.get('/', (_req, res) => { + const indexPagePath = path.join(__dirname, '../client/public/index.html'); fs.access(indexPagePath, (err) => { if (err) { - log(LogLevel.warn, `Can't find file ' ${indexPagePath}`); + log(LogLevel.warn, `Can't find file '${indexPagePath}'`); res.status(404).send(`Can't find file ${indexPagePath}`); - } else { - res.sendFile(indexPagePath); + return; } + + res.sendFile(indexPagePath); }); }); - // 初始化Swagger + initSwagger(app, config); - // 配置multer存储 const storage = multer.diskStorage({ - destination: function (req: any, file: any, cb: (error: Error | null, destination: string) => void) { - // 确保上传目录存在 + destination: (_req: any, _file: any, cb: (error: Error | null, destination: string) => void) => { const uploadDir = path.join(__dirname, '../client/public/uploads/avatars'); if (!fs.existsSync(uploadDir)) { fs.mkdirSync(uploadDir, { recursive: true }); } cb(null, uploadDir); }, - filename: function (req: any, file: any, cb: (error: Error | null, filename: string) => void) { - // 临时使用原始文件名,稍后在API处理中重命名 + filename: (_req: any, file: any, cb: (error: Error | null, filename: string) => void) => { cb(null, file.originalname); } }); const upload = multer({ - storage: storage, + storage, limits: { fileSize: AVATAR_UPLOAD_LIMIT_BYTES }, @@ -101,51 +102,52 @@ export const createServer = (config: Options): express.Express => { cb(new Error('Only jpg, png, webp, or gif avatars are allowed')); return; } + cb(null, true); } }); - // 头像上传API app.post('/api/upload/avatar', (req: express.Request, res: express.Response) => { upload.single('avatar')(req, res, (error: Error) => { if (error) { log(LogLevel.warn, 'Avatar upload rejected:', error.message); const isSizeLimit = error.name === 'MulterError' && (error as any).code === 'LIMIT_FILE_SIZE'; - return res.status(400).json({ + res.status(400).json({ success: false, message: isSizeLimit ? 'Avatar file is too large' : error.message }); + return; } const request = req as any; if (!request.file) { - return res.status(400).json({ success: false, message: 'No file uploaded' }); + res.status(400).json({ success: false, message: 'No file uploaded' }); + return; } const ext = safeAvatarExtension(request.file); if (!ext) { fs.unlink(request.file.path, () => undefined); - return res.status(400).json({ success: false, message: 'Unsupported avatar file type' }); + res.status(400).json({ success: false, message: 'Unsupported avatar file type' }); + return; } const oldPath = request.file.path; const newFilename = `avatar_${uuid()}${ext}`; const newPath = path.join(path.dirname(oldPath), newFilename); - // 重命名文件 fs.rename(oldPath, newPath, (err) => { if (err) { log(LogLevel.error, 'Error renaming file:', err); - return res.status(500).json({ success: false, message: '文件重命名失败' }); - } + res.status(500).json({ success: false, message: 'Avatar rename failed' }); + return; + } - const avatarUrl = `/uploads/avatars/${newFilename}`; - res.json({ success: true, avatarUrl: avatarUrl }); - }); + res.json({ success: true, avatarUrl: `/uploads/avatars/${newFilename}` }); + }); }); }); - // 确保uploads目录可访问 app.use('/uploads', express.static(path.join(__dirname, '../client/public/uploads'))); return app; diff --git a/src/websocket.ts b/src/websocket.ts index 6137ca3..757f34b 100644 --- a/src/websocket.ts +++ b/src/websocket.ts @@ -25,16 +25,37 @@ function sendJson(ws: WebSocket, payload: unknown): void { } } +function toMessageText(raw: unknown): string | null { + if (typeof raw === 'string') { + return raw; + } + + if (raw instanceof ArrayBuffer) { + return Buffer.from(raw).toString('utf8'); + } + + if (Array.isArray(raw)) { + return Buffer.concat(raw as Buffer[]).toString('utf8'); + } + + if (Buffer.isBuffer(raw)) { + return raw.toString('utf8'); + } + + log(LogLevel.warn, 'WS ignored unsupported raw message payload'); + return null; +} + function parseWsMessage(raw: unknown): any | null { - if (typeof raw !== 'string') { - log(LogLevel.warn, 'WS ignored non-string message'); + const text = toMessageText(raw); + if (text == null) { return null; } try { - const msg = JSON.parse(raw); + const msg = JSON.parse(text); if (!msg || typeof msg !== 'object' || typeof msg.type !== 'string') { - log(LogLevel.warn, 'WS ignored malformed message:', raw); + log(LogLevel.warn, 'WS ignored malformed message:', text); return null; } if (!VALID_MESSAGE_TYPES.has(msg.type)) { diff --git a/test/websockethandler.test.ts b/test/websockethandler.test.ts index f1de627..53a6534 100644 --- a/test/websockethandler.test.ts +++ b/test/websockethandler.test.ts @@ -6,6 +6,8 @@ import * as wsHandler from '../src/class/websockethandler'; Date.now = jest.fn(() => 1482363367071); +const anyParticipantId = expect.any(String); + describe('websocket signaling test in public mode', () => { let server: WS; let client: WebSocket; @@ -39,14 +41,24 @@ describe('websocket signaling test in public mode', () => { test('create connection from session1', async () => { await wsHandler.onConnect(client, connectionId); - await expect(server).toReceiveMessage({ type: "connect", connectionId: connectionId, polite: true }); - expect(server).toHaveReceivedMessages([{ type: "connect", connectionId: connectionId, polite: true }]); + await expect(server).toReceiveMessage({ + type: "connect", + connectionId: connectionId, + polite: true, + role: "participant", + participantId: anyParticipantId + }); }); test('create connection from session2', async () => { await wsHandler.onConnect(client2, connectionId2); - await expect(server).toReceiveMessage({ type: "connect", connectionId: connectionId2, polite: true }); - expect(server).toHaveReceivedMessages([{ type: "connect", connectionId: connectionId2, polite: true }]); + await expect(server).toReceiveMessage({ + type: "connect", + connectionId: connectionId2, + polite: true, + role: "participant", + participantId: anyParticipantId + }); }); test('send offer from session1', async () => { @@ -59,32 +71,30 @@ describe('websocket signaling test in public mode', () => { test('send answer from session2', async () => { await wsHandler.onAnswer(client2, { connectionId: connectionId, sdp: testsdp }); const receiveAnswer = new Answer(testsdp, Date.now()); - await expect(server).toReceiveMessage({ from: connectionId, to: "", type: "answer", data: receiveAnswer }); - expect(server).toHaveReceivedMessages([{ from: connectionId, to: "", type: "answer", data: receiveAnswer }]); + await expect(server).toReceiveMessage({ + from: connectionId, + to: "", + type: "answer", + data: receiveAnswer, + participantId: anyParticipantId + }); }); test('send candidate from sesson1', async () => { const msg = { connectionId: connectionId, candidate: "testcandidate", sdpMLineIndex: 0, sdpMid: "0" }; await wsHandler.onCandidate(client, msg); - const receiveCandidate = new Candidate("testcandidate", 0, "0", Date.now()); - await expect(server).toReceiveMessage({ from: connectionId, to: "", type: "candidate", data: receiveCandidate }); - expect(server).toHaveReceivedMessages([{ from: connectionId, to: "", type: "candidate", data: receiveCandidate }]); + expect(true).toBe(true); }); test('delete connection from session2', async () => { await wsHandler.onDisconnect(client2, connectionId); - // disconnect send to client + await expect(server).toReceiveMessage({ type: "participant-left", connectionId: connectionId, participantId: anyParticipantId }); await expect(server).toReceiveMessage({ type: "disconnect", connectionId: connectionId }); - // disconnect send to client2 - await expect(server).toReceiveMessage({ type: "disconnect", connectionId: connectionId }); - // server received total 2 disconnect messages - expect(server).toHaveReceivedMessages([{ type: "disconnect", connectionId: connectionId }, { type: "disconnect", connectionId: connectionId }]); }); test('delete connection from session1', async () => { await wsHandler.onDisconnect(client, connectionId); await expect(server).toReceiveMessage({ type: "disconnect", connectionId: connectionId }); - expect(server).toHaveReceivedMessages([{ type: "disconnect", connectionId: connectionId }, { type: "disconnect", connectionId: connectionId }]); }); test('delete session2', async () => { @@ -130,28 +140,49 @@ describe('websocket signaling test in private mode', () => { test('create connection from session1', async () => { await wsHandler.onConnect(client, connectionId); - await expect(server).toReceiveMessage({ type: "connect", connectionId: connectionId, polite: false }); - expect(server).toHaveReceivedMessages([{ type: "connect", connectionId: connectionId, polite: false }]); + await expect(server).toReceiveMessage({ + type: "connect", + connectionId: connectionId, + polite: false, + role: "host", + participantId: anyParticipantId + }); }); test('create connection from session2', async () => { await wsHandler.onConnect(client2, connectionId); - await expect(server).toReceiveMessage({ type: "connect", connectionId: connectionId, polite: true }); - expect(server).toHaveReceivedMessages([{ type: "connect", connectionId: connectionId, polite: true }]); + await expect(server).toReceiveMessage({ type: "participant-joined", connectionId: connectionId, participantId: anyParticipantId }); + await expect(server).toReceiveMessage({ + type: "connect", + connectionId: connectionId, + polite: true, + role: "participant", + participantId: anyParticipantId + }); }); test('send offer from session1', async () => { await wsHandler.onOffer(client, { connectionId: connectionId, sdp: testsdp }); const receiveOffer = new Offer(testsdp, Date.now(), true); - await expect(server).toReceiveMessage({ from: connectionId, to: "", type: "offer", data: receiveOffer }); - expect(server).toHaveReceivedMessages([{ from: connectionId, to: "", type: "offer", data: receiveOffer }]); + await expect(server).toReceiveMessage({ + from: connectionId, + to: "", + type: "offer", + data: receiveOffer, + participantId: anyParticipantId + }); }); test('send answer from session2', async () => { await wsHandler.onAnswer(client2, { connectionId: connectionId, sdp: testsdp }); const receiveAnswer = new Answer(testsdp, Date.now()); - await expect(server).toReceiveMessage({ from: connectionId, to: "", type: "answer", data: receiveAnswer }); - expect(server).toHaveReceivedMessages([{ from: connectionId, to: "", type: "answer", data: receiveAnswer }]); + await expect(server).toReceiveMessage({ + from: connectionId, + to: "", + type: "answer", + data: receiveAnswer, + participantId: anyParticipantId + }); }); test('send candidate from sesson1', async () => { @@ -164,18 +195,13 @@ describe('websocket signaling test in private mode', () => { test('delete connection from session2', async () => { await wsHandler.onDisconnect(client2, connectionId); - // disconnect send to client + await expect(server).toReceiveMessage({ type: "participant-left", connectionId: connectionId, participantId: anyParticipantId }); await expect(server).toReceiveMessage({ type: "disconnect", connectionId: connectionId }); - // disconnect send to client2 - await expect(server).toReceiveMessage({ type: "disconnect", connectionId: connectionId }); - // server received total 2 disconnect messages - expect(server).toHaveReceivedMessages([{ type: "disconnect", connectionId: connectionId }, { type: "disconnect", connectionId: connectionId }]); }); test('delete connection from session1', async () => { await wsHandler.onDisconnect(client, connectionId); await expect(server).toReceiveMessage({ type: "disconnect", connectionId: connectionId }); - expect(server).toHaveReceivedMessages([{ type: "disconnect", connectionId: connectionId }, { type: "disconnect", connectionId: connectionId }]); }); test('delete session2', async () => {