21#include <freerdp/log.h>
/* Log tag handed to the WLog_* calls in this file. */
24#define TAG FREERDP_TAG("core.gateway.websocket")
/* Websocket connection context. Only part of the member list is visible in this excerpt. */
26struct s_websocket_context
/* Opcode byte remembered by websocket_context_read() for frames whose low nibble is
 * neither the continuation opcode nor >= 0x08; websocket_handle_payload() uses it as
 * the effective opcode when a continuation frame arrives. */
33 BYTE fragmentOriginalOpcode;
/* Number of extended-length bytes consumed so far; reset to 0 when the length/mask
 * byte is parsed and advanced while the 2- or 8-byte length is read. */
34 BYTE lengthAndMaskPosition;
/* Current state of the frame parser that websocket_context_read() switches on. */
35 WEBSOCKET_STATE state;
/* Forward declaration: the definition appears later in this file, after its first
 * use in websocket_context_mask_and_send(). */
39static int websocket_write_all(BIO* bio,
const BYTE* data,
size_t length);
/*
 * Masks the payload held in sDataPacket with maskingKey, appends the masked bytes
 * to sPacket and writes the whole of sPacket to bio via websocket_write_all().
 *
 * sPacket is released with Stream_Free() after the write attempt.
 * Returns TRUE only when the write status is non-negative and equals the number
 * of bytes that sPacket held.
 * NOTE(review): the maskingKey parameter line and the early-exit paths are not
 * visible in this excerpt.
 */
41BOOL websocket_context_mask_and_send(BIO* bio,
wStream* sPacket,
wStream* sDataPacket,
/* Payload size is the sealed length of the data stream; reading restarts at offset 0. */
44 const size_t len = Stream_Length(sDataPacket);
45 Stream_SetPosition(sDataPacket, 0);
/* Make room in the outgoing packet for the masked payload. */
47 if (!Stream_EnsureRemainingCapacity(sPacket, len))
/* Bulk pass: XOR four payload bytes at a time with the whole 32-bit key. */
52 for (; streamPos + 4 <= len; streamPos += 4)
54 const uint32_t data = Stream_Get_UINT32(sDataPacket);
55 Stream_Write_UINT32(sPacket, data ^ maskingKey);
/* Tail pass: the remaining (fewer than four) bytes are XORed one at a time. */
59 for (; streamPos < len; streamPos++)
/* Select the key byte matching this payload offset (offset modulo 4). */
62 BYTE* partialMask = ((BYTE*)&maskingKey) + (streamPos % 4);
63 Stream_Read_UINT8(sDataPacket, data);
64 Stream_Write_UINT8(sPacket, data ^ *partialMask);
/* Fix the stream length at the current write position before sending. */
67 Stream_SealLength(sPacket);
70 const size_t size = Stream_Length(sPacket);
71 const int status = websocket_write_all(bio, Stream_Buffer(sPacket), size);
/* The packet stream is freed here regardless of the write status. */
72 Stream_Free(sPacket, TRUE);
/* Success requires a non-negative status that covers every byte of the packet. */
74 return !((status < 0) || ((
size_t)status != size));
/*
 * Allocates a stream of fullLen bytes and writes a websocket frame header for a
 * payload of len bytes into it: the FIN/opcode byte, the length field with the
 * mask bit set, and a freshly generated masking key.
 *
 * The generated key is also returned to the caller through pMaskingKey, which
 * must not be NULL (asserted).
 * NOTE(review): the computation of fullLen, the failure paths and the return
 * statement are not visible in this excerpt.
 */
77wStream* websocket_context_packet_new(
size_t len, WEBSOCKET_OPCODE opcode, UINT32* pMaskingKey)
79 WINPR_ASSERT(pMaskingKey);
86 else if (len < 0x10000)
/* Draw the masking key from winpr_RAND(); a negative result is treated as an error. */
91 UINT32 maskingKey = 0;
92 if (winpr_RAND(&maskingKey,
sizeof(maskingKey)) < 0)
95 wStream* sWS = Stream_New(
nullptr, fullLen);
/* First header byte: FIN bit combined with the opcode (frame is sent unfragmented). */
99 Stream_Write_UINT8(sWS, (UINT8)(WEBSOCKET_FIN_BIT | opcode));
/* Short form: the length fits in the second header byte next to the mask bit. */
101 Stream_Write_UINT8(sWS, (UINT8)len | WEBSOCKET_MASK_BIT);
102 else if (len < 0x10000)
/* Medium form: marker 126 followed by a 16-bit big-endian length. */
104 Stream_Write_UINT8(sWS, 126 | WEBSOCKET_MASK_BIT);
105 Stream_Write_UINT16_BE(sWS, (UINT16)len);
/* Long form: marker 127 followed by a 64-bit big-endian length, written as a zero
 * high word and the low 32 bits of len.
 * NOTE(review): the cast drops any bits of len above 32; confirm callers bound len
 * (websocket_write_all rejects lengths above INT32_MAX). */
109 Stream_Write_UINT8(sWS, 127 | WEBSOCKET_MASK_BIT);
110 Stream_Write_UINT32_BE(sWS, 0);
111 Stream_Write_UINT32_BE(sWS, (UINT32)len);
/* Append the masking key and hand the same value back to the caller. */
113 Stream_Write_UINT32(sWS, maskingKey);
114 *pMaskingKey = maskingKey;
/*
 * Sends the contents of sPacket as one masked websocket frame with the given
 * opcode: a header stream is built with websocket_context_packet_new() and the
 * payload is masked and written by websocket_context_mask_and_send().
 *
 * Returns the result of websocket_context_mask_and_send() on the visible path.
 * NOTE(review): what happens when closeSent is already set, and when header
 * allocation fails, is not visible in this excerpt.
 */
118BOOL websocket_context_write_wstream(websocket_context* context, BIO* bio,
wStream* sPacket,
119 WEBSOCKET_OPCODE opcode)
121 WINPR_ASSERT(context);
/* A close frame has already been sent on this context. */
123 if (context->closeSent)
/* Remember that a close frame is going out so later writes can see it. */
126 if (opcode == WebsocketCloseOpcode)
127 context->closeSent = TRUE;
130 WINPR_ASSERT(sPacket);
/* The payload length is the sealed length of the caller's stream. */
132 const size_t len = Stream_Length(sPacket);
133 uint32_t maskingKey = 0;
134 wStream* sWS = websocket_context_packet_new(len, opcode, &maskingKey);
/* sWS is handed over to the send helper, which frees it after the write attempt. */
138 return websocket_context_mask_and_send(bio, sWS, sPacket, maskingKey);
/*
 * Writes length bytes from data to bio, calling BIO_write() repeatedly until the
 * running offset reaches length.
 *
 * NOTE(review): return values and the bodies of the error branches are not
 * visible in this excerpt.
 */
141int websocket_write_all(BIO* bio,
const BYTE* data,
size_t length)
/* Lengths above INT32_MAX are special-cased; BIO_write() takes an int count. */
147 if (length > INT32_MAX)
150 while (offset < length)
/* Try to write everything that is still outstanding. */
153 const size_t diff = length - offset;
154 int status = BIO_write(bio, &data[offset], (
int)diff);
/* Advance past the bytes that were accepted by the BIO. */
157 offset += (size_t)status;
/* Distinguish a retryable condition from other outcomes. */
160 if (!BIO_should_retry(bio))
/* Blocked on writing: wait on the BIO with a timeout argument of 100
 * (unit not shown here - TODO confirm) before the next attempt. */
163 if (BIO_write_blocked(bio))
165 const long rstatus = BIO_wait_write(bio, 100);
/* Blocked on reading instead; the handling is not visible in this excerpt. */
169 else if (BIO_read_blocked(bio))
/*
 * Sends isize bytes from buf as a websocket frame with the given opcode by
 * wrapping the buffer in a stack-allocated stream and passing it to
 * websocket_context_write_wstream().
 *
 * NOTE(review): validation of isize and the return values are not visible in
 * this excerpt.
 */
179int websocket_context_write(websocket_context* context, BIO* bio,
const BYTE* buf,
int isize,
180 WEBSOCKET_OPCODE opcode)
/* Static const stream view over the caller's buffer; no copy is made here. */
188 wStream sbuffer = WINPR_C_ARRAY_INIT;
189 wStream* s = Stream_StaticConstInit(&sbuffer, buf, (
size_t)isize);
190 if (!websocket_context_write_wstream(context, bio, s, opcode))
/*
 * Reads payload bytes of the current frame from bio into pBuffer, limited to the
 * smaller of size and the remaining payload length, and decreases
 * encodingContext->payloadLength by the number of bytes read.
 *
 * When no payload remains, the parser state is set back to
 * WebsocketStateOpcodeAndFin so the next frame header is expected.
 * NOTE(review): the return statements are not visible in this excerpt.
 */
195static int websocket_read_data(BIO* bio, BYTE* pBuffer,
size_t size,
196 websocket_context* encodingContext)
201 WINPR_ASSERT(pBuffer);
202 WINPR_ASSERT(encodingContext);
/* Nothing left in this frame: go back to expecting a new frame header. */
204 if (encodingContext->payloadLength == 0)
206 encodingContext->state = WebsocketStateOpcodeAndFin;
/* Request no more than the remaining payload and no more than the buffer holds. */
211 (encodingContext->payloadLength < size ? encodingContext->payloadLength : size);
/* BIO_read() takes an int count, so larger requests are special-cased. */
212 if (rlen > INT32_MAX)
216 status = BIO_read(bio, pBuffer, (
int)rlen);
/* Non-positive results, or more bytes than the frame has left, are not consumed. */
217 if ((status <= 0) || ((
size_t)status > encodingContext->payloadLength))
220 encodingContext->payloadLength -= (size_t)status;
/* Frame fully consumed: the next read starts with a new frame header. */
222 if (encodingContext->payloadLength == 0)
223 encodingContext->state = WebsocketStateOpcodeAndFin;
/*
 * Reads payload bytes of the current frame into the context's
 * responseStreamBuffer (used for frames answered or dropped internally) and
 * advances the stream position by the number of bytes read.
 *
 * NOTE(review): the return statements and error branches are not visible in
 * this excerpt. The log message below misspells "payloadLength"; it is left
 * untouched here because it is a runtime string.
 */
228static int websocket_read_wstream(BIO* bio, websocket_context* encodingContext)
231 WINPR_ASSERT(encodingContext);
233 wStream* s = encodingContext->responseStreamBuffer;
/* Nothing left in this frame: go back to expecting a new frame header. */
236 if (encodingContext->payloadLength == 0)
238 encodingContext->state = WebsocketStateOpcodeAndFin;
/* The stream must be able to hold the rest of the payload; failure is logged. */
242 if (!Stream_EnsureRemainingCapacity(s, encodingContext->payloadLength))
245 "wStream::capacity [%" PRIuz
"] != encodingContext::paylaodLangth [%" PRIuz
"]",
246 Stream_GetRemainingCapacity(s), encodingContext->payloadLength);
/* Read straight into the stream at its current write position. */
250 const int status = websocket_read_data(bio, Stream_Pointer(s), Stream_GetRemainingCapacity(s),
/* Move the stream position past the bytes just stored. */
255 if (!Stream_SafeSeek(s, (
size_t)status))
/*
 * Sends a close frame carrying the contents of s by forwarding to
 * websocket_context_write_wstream() with WebsocketCloseOpcode, and returns its
 * result.
 */
261static BOOL websocket_reply_close(BIO* bio, websocket_context* context,
wStream* s)
265 return websocket_context_write_wstream(context, bio, s, WebsocketCloseOpcode);
/*
 * Answers a ping: when s has a non-zero position, its contents are sent back as
 * a pong frame; otherwise a close frame is sent through websocket_reply_close().
 *
 * NOTE(review): the fallback passes nullptr as the stream, while
 * websocket_context_write_wstream() asserts that its stream argument is
 * non-NULL - confirm that lines not visible here prevent that combination.
 */
268static BOOL websocket_reply_pong(BIO* bio, websocket_context* context,
wStream* s)
273 if (Stream_GetPosition(s) != 0)
274 return websocket_context_write_wstream(context, bio, s, WebsocketPongOpcode);
276 return websocket_reply_close(bio, context,
nullptr);
/*
 * Consumes payload bytes of the current frame according to its effective
 * opcode. Binary payload is read into the caller's pBuffer; ping, pong, close
 * and unhandled opcodes are read into encodingContext->responseStreamBuffer,
 * which is rewound to position 0 afterwards.
 *
 * NOTE(review): the return values and the per-case status checks are not
 * visible in this excerpt.
 */
279static int websocket_handle_payload(BIO* bio, BYTE* pBuffer,
size_t size,
280 websocket_context* encodingContext)
285 WINPR_ASSERT(pBuffer);
286 WINPR_ASSERT(encodingContext);
/* A continuation frame inherits the opcode stored in fragmentOriginalOpcode;
 * only the low nibble of either byte is considered. */
288 const BYTE effectiveOpcode = ((encodingContext->opcode & 0xf) == WebsocketContinuationOpcode
289 ? encodingContext->fragmentOriginalOpcode & 0xf
290 : encodingContext->opcode & 0xf);
292 switch (effectiveOpcode)
/* Binary data goes directly to the caller's buffer. */
294 case WebsocketBinaryOpcode:
296 status = websocket_read_data(bio, pBuffer, size, encodingContext);
/* Ping: collect the payload, then reply once the whole frame has been read. */
302 case WebsocketPingOpcode:
304 status = websocket_read_wstream(bio, encodingContext);
308 if (encodingContext->payloadLength == 0)
310 websocket_reply_pong(bio, encodingContext, encodingContext->responseStreamBuffer);
311 Stream_SetPosition(encodingContext->responseStreamBuffer, 0);
/* Pong: the payload is read and the buffer rewound; no reply is sent here. */
315 case WebsocketPongOpcode:
317 status = websocket_read_wstream(bio, encodingContext);
321 Stream_SetPosition(encodingContext->responseStreamBuffer, 0);
/* Close: collect the payload, send a close frame with it once the frame is
 * complete, and mark the context as having sent close. */
324 case WebsocketCloseOpcode:
326 status = websocket_read_wstream(bio, encodingContext);
330 if (encodingContext->payloadLength == 0)
332 websocket_reply_close(bio, encodingContext, encodingContext->responseStreamBuffer);
333 encodingContext->closeSent = TRUE;
334 Stream_SetPosition(encodingContext->responseStreamBuffer, 0);
/* Opcodes not handled above: warn, read the payload and rewind the buffer so
 * the data is dropped. */
339 WLog_WARN(TAG,
"Unimplemented websocket opcode %" PRIx8
". Dropping", effectiveOpcode);
341 status = websocket_read_wstream(bio, encodingContext);
344 Stream_SetPosition(encodingContext->responseStreamBuffer, 0);
/*
 * Incremental websocket frame reader. Depending on encodingContext->state it
 * reads the opcode byte, the length/mask byte, the extended length bytes or the
 * payload from bio. Payload bytes are accumulated in effectiveDataLen, which is
 * the value returned once the caller's buffer is full.
 *
 * Header bytes are read one at a time, so parsing can resume from the stored
 * state on a later call.
 * NOTE(review): the enclosing loop, the status checks guarding the early
 * returns and the alternative return values are not visible in this excerpt.
 */
352int websocket_context_read(websocket_context* encodingContext, BIO* bio, BYTE* pBuffer,
size_t size)
/* Running total of payload bytes delivered during this call. */
355 size_t effectiveDataLen = 0;
358 WINPR_ASSERT(pBuffer);
359 WINPR_ASSERT(encodingContext);
363 switch (encodingContext->state)
/* First header byte: stored as the frame's opcode byte. */
365 case WebsocketStateOpcodeAndFin:
367 BYTE buffer[1] = WINPR_C_ARRAY_INIT;
370 status = BIO_read(bio, (
char*)buffer,
sizeof(buffer));
372 return (effectiveDataLen > 0 ? WINPR_ASSERTING_INT_CAST(
int, effectiveDataLen)
375 encodingContext->opcode = buffer[0];
/* Remember the opcode of frames that are neither continuation frames nor have a
 * low nibble of 0x08 or above, so later continuation frames can reuse it. */
376 if (((encodingContext->opcode & 0xf) != WebsocketContinuationOpcode) &&
377 (encodingContext->opcode & 0xf) < 0x08)
378 encodingContext->fragmentOriginalOpcode = encodingContext->opcode;
379 encodingContext->state = WebsocketStateLengthAndMasking;
/* Second header byte: mask flag plus the 7-bit length field. */
382 case WebsocketStateLengthAndMasking:
384 BYTE buffer[1] = WINPR_C_ARRAY_INIT;
387 status = BIO_read(bio, (
char*)buffer,
sizeof(buffer));
389 return (effectiveDataLen > 0 ? WINPR_ASSERTING_INT_CAST(
int, effectiveDataLen)
392 encodingContext->masking = ((buffer[0] & WEBSOCKET_MASK_BIT) == WEBSOCKET_MASK_BIT);
/* Reset the extended-length bookkeeping for this frame. */
393 encodingContext->lengthAndMaskPosition = 0;
394 encodingContext->payloadLength = 0;
395 const BYTE len = buffer[0] & 0x7f;
/* The length field either holds the payload length itself or selects the short
 * or long extended-length state (the selecting conditions are not visible). */
398 encodingContext->payloadLength = len;
399 encodingContext->state = (encodingContext->masking ? WebSocketStateMaskingKey
400 : WebSocketStatePayload);
403 encodingContext->state = WebsocketStateShortLength;
405 encodingContext->state = WebsocketStateLongLength;
/* Extended length: 2 bytes in the short state, 8 bytes in the long state. */
408 case WebsocketStateShortLength:
409 case WebsocketStateLongLength:
411 BYTE buffer[1] = WINPR_C_ARRAY_INIT;
412 const BYTE lenLength =
413 (encodingContext->state == WebsocketStateShortLength ? 2 : 8);
/* lengthAndMaskPosition counts the length bytes consumed so far. */
414 while (encodingContext->lengthAndMaskPosition < lenLength)
417 status = BIO_read(bio, (
char*)buffer,
sizeof(buffer));
419 return (effectiveDataLen > 0
420 ? WINPR_ASSERTING_INT_CAST(
int, effectiveDataLen)
/* Range check ahead of the cast of status to BYTE below. */
422 if (status > UINT8_MAX)
/* Shift the previous bytes up and append the new one: most significant byte first. */
424 encodingContext->payloadLength =
425 (encodingContext->payloadLength) << 8 | buffer[0];
426 encodingContext->lengthAndMaskPosition +=
427 WINPR_ASSERTING_INT_CAST(BYTE, status);
/* Length complete: continue with the masking key if the mask bit was set,
 * otherwise with the payload. */
429 encodingContext->state =
430 (encodingContext->masking ? WebSocketStateMaskingKey : WebSocketStatePayload);
/* A masked frame from the server is reported through the log message below. */
433 case WebSocketStateMaskingKey:
436 TAG,
"Websocket Server sends data with masking key. This is against RFC 6455.");
/* Payload: delegate to the opcode-specific handler. */
439 case WebSocketStatePayload:
441 status = websocket_handle_payload(bio, pBuffer, size, encodingContext);
443 return (effectiveDataLen > 0 ? WINPR_ASSERTING_INT_CAST(
int, effectiveDataLen)
/* Account for the bytes delivered by the payload handler. */
446 effectiveDataLen += WINPR_ASSERTING_INT_CAST(
size_t, status);
/* The caller's buffer is full: report everything delivered so far. */
448 if (WINPR_ASSERTING_INT_CAST(
size_t, status) >= size)
449 return WINPR_ASSERTING_INT_CAST(
int, effectiveDataLen);
/* Otherwise reduce the remaining room by the bytes just delivered. */
451 size -= WINPR_ASSERTING_INT_CAST(
size_t, status);
/*
 * Allocates a zero-initialised websocket context, creates its 1024-byte
 * responseStreamBuffer and resets the parser state with
 * websocket_context_reset().
 *
 * NOTE(review): the return statements are not visible in this excerpt; the
 * websocket_context_free() call presumably belongs to the failure path.
 */
461websocket_context* websocket_context_new(
void)
463 websocket_context* context = calloc(1,
sizeof(websocket_context));
/* Buffer that collects the payload of frames handled internally. */
467 context->responseStreamBuffer = Stream_New(
nullptr, 1024);
468 if (!context->responseStreamBuffer)
471 if (!websocket_context_reset(context))
476 websocket_context_free(context);
/*
 * Releases a websocket context. The visible part frees the context's
 * responseStreamBuffer; any further cleanup is not shown in this excerpt.
 */
480void websocket_context_free(websocket_context* context)
485 Stream_Free(context->responseStreamBuffer, TRUE);
/*
 * Puts the context back into its initial parser state: the next byte read is
 * treated as a frame's opcode byte and the response buffer is rewound.
 *
 * Returns the result of Stream_SetPosition() on the response buffer.
 * context must not be NULL (asserted).
 */
489BOOL websocket_context_reset(websocket_context* context)
491 WINPR_ASSERT(context);
493 context->state = WebsocketStateOpcodeAndFin;
494 return Stream_SetPosition(context->responseStreamBuffer, 0);