FreeRDP
Loading...
Searching...
No Matches
websocket.c
1
20#include "websocket.h"
21#include <freerdp/log.h>
22#include "../tcp.h"
23
24#define TAG FREERDP_TAG("core.gateway.websocket")
25
26struct s_websocket_context
27{
28 size_t payloadLength;
29 uint32_t maskingKey;
30 BOOL masking;
31 BOOL closeSent;
32 BYTE opcode;
33 BYTE fragmentOriginalOpcode;
34 BYTE lengthAndMaskPosition;
35 WEBSOCKET_STATE state;
36 wStream* responseStreamBuffer;
37};
38
/* Writes all `length` bytes of `data` to `bio`, retrying transient BIO failures.
 * Returns `length` on success, -1 on error and -2 if the write was aborted because the
 * BIO first needs to be read. Defined further down in this file. */
static int websocket_write_all(BIO* bio, const BYTE* data, size_t length);
40
41BOOL websocket_context_mask_and_send(BIO* bio, wStream* sPacket, wStream* sDataPacket,
42 UINT32 maskingKey)
43{
44 const size_t len = Stream_Length(sDataPacket);
45 Stream_SetPosition(sDataPacket, 0);
46
47 if (!Stream_EnsureRemainingCapacity(sPacket, len))
48 return FALSE;
49
50 /* mask as much as possible with 32bit access */
51 size_t streamPos = 0;
52 for (; streamPos + 4 <= len; streamPos += 4)
53 {
54 const uint32_t data = Stream_Get_UINT32(sDataPacket);
55 Stream_Write_UINT32(sPacket, data ^ maskingKey);
56 }
57
58 /* mask the rest byte by byte */
59 for (; streamPos < len; streamPos++)
60 {
61 BYTE data = 0;
62 BYTE* partialMask = ((BYTE*)&maskingKey) + (streamPos % 4);
63 Stream_Read_UINT8(sDataPacket, data);
64 Stream_Write_UINT8(sPacket, data ^ *partialMask);
65 }
66
67 Stream_SealLength(sPacket);
68
69 ERR_clear_error();
70 const size_t size = Stream_Length(sPacket);
71 const int status = websocket_write_all(bio, Stream_Buffer(sPacket), size);
72 Stream_Free(sPacket, TRUE);
73
74 return !((status < 0) || ((size_t)status != size));
75}
76
77wStream* websocket_context_packet_new(size_t len, WEBSOCKET_OPCODE opcode, UINT32* pMaskingKey)
78{
79 WINPR_ASSERT(pMaskingKey);
80 if (len > INT_MAX)
81 return nullptr;
82
83 size_t fullLen = 0;
84 if (len < 126)
85 fullLen = len + 6; /* 2 byte "mini header" + 4 byte masking key */
86 else if (len < 0x10000)
87 fullLen = len + 8; /* 2 byte "mini header" + 2 byte length + 4 byte masking key */
88 else
89 fullLen = len + 14; /* 2 byte "mini header" + 8 byte length + 4 byte masking key */
90
91 UINT32 maskingKey = 0;
92 if (winpr_RAND(&maskingKey, sizeof(maskingKey)) < 0)
93 return nullptr;
94
95 wStream* sWS = Stream_New(nullptr, fullLen);
96 if (!sWS)
97 return nullptr;
98
99 Stream_Write_UINT8(sWS, (UINT8)(WEBSOCKET_FIN_BIT | opcode));
100 if (len < 126)
101 Stream_Write_UINT8(sWS, (UINT8)len | WEBSOCKET_MASK_BIT);
102 else if (len < 0x10000)
103 {
104 Stream_Write_UINT8(sWS, 126 | WEBSOCKET_MASK_BIT);
105 Stream_Write_UINT16_BE(sWS, (UINT16)len);
106 }
107 else
108 {
109 Stream_Write_UINT8(sWS, 127 | WEBSOCKET_MASK_BIT);
110 Stream_Write_UINT32_BE(sWS, 0); /* payload is limited to INT_MAX */
111 Stream_Write_UINT32_BE(sWS, (UINT32)len);
112 }
113 Stream_Write_UINT32(sWS, maskingKey);
114 *pMaskingKey = maskingKey;
115 return sWS;
116}
117
118BOOL websocket_context_write_wstream(websocket_context* context, BIO* bio, wStream* sPacket,
119 WEBSOCKET_OPCODE opcode)
120{
121 WINPR_ASSERT(context);
122
123 if (context->closeSent)
124 return FALSE;
125
126 if (opcode == WebsocketCloseOpcode)
127 context->closeSent = TRUE;
128
129 WINPR_ASSERT(bio);
130 WINPR_ASSERT(sPacket);
131
132 const size_t len = Stream_Length(sPacket);
133 uint32_t maskingKey = 0;
134 wStream* sWS = websocket_context_packet_new(len, opcode, &maskingKey);
135 if (!sWS)
136 return FALSE;
137
138 return websocket_context_mask_and_send(bio, sWS, sPacket, maskingKey);
139}
140
141int websocket_write_all(BIO* bio, const BYTE* data, size_t length)
142{
143 WINPR_ASSERT(bio);
144 WINPR_ASSERT(data);
145 size_t offset = 0;
146
147 if (length > INT32_MAX)
148 return -1;
149
150 while (offset < length)
151 {
152 ERR_clear_error();
153 const size_t diff = length - offset;
154 int status = BIO_write(bio, &data[offset], (int)diff);
155
156 if (status > 0)
157 offset += (size_t)status;
158 else
159 {
160 if (!BIO_should_retry(bio))
161 return -1;
162
163 if (BIO_write_blocked(bio))
164 {
165 const long rstatus = BIO_wait_write(bio, 100);
166 if (rstatus < 0)
167 return -1;
168 }
169 else if (BIO_read_blocked(bio))
170 return -2; /* Abort write, there is data that must be read */
171 else
172 USleep(100);
173 }
174 }
175
176 return (int)length;
177}
178
179int websocket_context_write(websocket_context* context, BIO* bio, const BYTE* buf, int isize,
180 WEBSOCKET_OPCODE opcode)
181{
182 WINPR_ASSERT(bio);
183 WINPR_ASSERT(buf);
184
185 if (isize < 0)
186 return -1;
187
188 wStream sbuffer = WINPR_C_ARRAY_INIT;
189 wStream* s = Stream_StaticConstInit(&sbuffer, buf, (size_t)isize);
190 if (!websocket_context_write_wstream(context, bio, s, opcode))
191 return -2;
192 return isize;
193}
194
195static int websocket_read_data(BIO* bio, BYTE* pBuffer, size_t size,
196 websocket_context* encodingContext)
197{
198 int status = 0;
199
200 WINPR_ASSERT(bio);
201 WINPR_ASSERT(pBuffer);
202 WINPR_ASSERT(encodingContext);
203
204 if (encodingContext->payloadLength == 0)
205 {
206 encodingContext->state = WebsocketStateOpcodeAndFin;
207 return 0;
208 }
209
210 const size_t rlen =
211 (encodingContext->payloadLength < size ? encodingContext->payloadLength : size);
212 if (rlen > INT32_MAX)
213 return -1;
214
215 ERR_clear_error();
216 status = BIO_read(bio, pBuffer, (int)rlen);
217 if ((status <= 0) || ((size_t)status > encodingContext->payloadLength))
218 return status;
219
220 encodingContext->payloadLength -= (size_t)status;
221
222 if (encodingContext->payloadLength == 0)
223 encodingContext->state = WebsocketStateOpcodeAndFin;
224
225 return status;
226}
227
228static int websocket_read_wstream(BIO* bio, websocket_context* encodingContext)
229{
230 WINPR_ASSERT(bio);
231 WINPR_ASSERT(encodingContext);
232
233 wStream* s = encodingContext->responseStreamBuffer;
234 WINPR_ASSERT(s);
235
236 if (encodingContext->payloadLength == 0)
237 {
238 encodingContext->state = WebsocketStateOpcodeAndFin;
239 return 0;
240 }
241
242 if (!Stream_EnsureRemainingCapacity(s, encodingContext->payloadLength))
243 {
244 WLog_WARN(TAG,
245 "wStream::capacity [%" PRIuz "] != encodingContext::paylaodLangth [%" PRIuz "]",
246 Stream_GetRemainingCapacity(s), encodingContext->payloadLength);
247 return -1;
248 }
249
250 const int status = websocket_read_data(bio, Stream_Pointer(s), Stream_GetRemainingCapacity(s),
251 encodingContext);
252 if (status < 0)
253 return status;
254
255 if (!Stream_SafeSeek(s, (size_t)status))
256 return -1;
257
258 return status;
259}
260
261static BOOL websocket_reply_close(BIO* bio, websocket_context* context, wStream* s)
262{
263 WINPR_ASSERT(bio);
264
265 return websocket_context_write_wstream(context, bio, s, WebsocketCloseOpcode);
266}
267
268static BOOL websocket_reply_pong(BIO* bio, websocket_context* context, wStream* s)
269{
270 WINPR_ASSERT(bio);
271 WINPR_ASSERT(s);
272
273 if (Stream_GetPosition(s) != 0)
274 return websocket_context_write_wstream(context, bio, s, WebsocketPongOpcode);
275
276 return websocket_reply_close(bio, context, nullptr);
277}
278
279static int websocket_handle_payload(BIO* bio, BYTE* pBuffer, size_t size,
280 websocket_context* encodingContext)
281{
282 int status = 0;
283
284 WINPR_ASSERT(bio);
285 WINPR_ASSERT(pBuffer);
286 WINPR_ASSERT(encodingContext);
287
288 const BYTE effectiveOpcode = ((encodingContext->opcode & 0xf) == WebsocketContinuationOpcode
289 ? encodingContext->fragmentOriginalOpcode & 0xf
290 : encodingContext->opcode & 0xf);
291
292 switch (effectiveOpcode)
293 {
294 case WebsocketBinaryOpcode:
295 {
296 status = websocket_read_data(bio, pBuffer, size, encodingContext);
297 if (status < 0)
298 return status;
299
300 return status;
301 }
302 case WebsocketPingOpcode:
303 {
304 status = websocket_read_wstream(bio, encodingContext);
305 if (status < 0)
306 return status;
307
308 if (encodingContext->payloadLength == 0)
309 {
310 websocket_reply_pong(bio, encodingContext, encodingContext->responseStreamBuffer);
311 Stream_SetPosition(encodingContext->responseStreamBuffer, 0);
312 }
313 }
314 break;
315 case WebsocketPongOpcode:
316 {
317 status = websocket_read_wstream(bio, encodingContext);
318 if (status < 0)
319 return status;
320 /* We don“t care about pong response data, discard. */
321 Stream_SetPosition(encodingContext->responseStreamBuffer, 0);
322 }
323 break;
324 case WebsocketCloseOpcode:
325 {
326 status = websocket_read_wstream(bio, encodingContext);
327 if (status < 0)
328 return status;
329
330 if (encodingContext->payloadLength == 0)
331 {
332 websocket_reply_close(bio, encodingContext, encodingContext->responseStreamBuffer);
333 encodingContext->closeSent = TRUE;
334 Stream_SetPosition(encodingContext->responseStreamBuffer, 0);
335 }
336 }
337 break;
338 default:
339 WLog_WARN(TAG, "Unimplemented websocket opcode %" PRIx8 ". Dropping", effectiveOpcode);
340
341 status = websocket_read_wstream(bio, encodingContext);
342 if (status < 0)
343 return status;
344 Stream_SetPosition(encodingContext->responseStreamBuffer, 0);
345 break;
346 }
347 /* return how many bytes have been written to pBuffer.
348 * Only WebsocketBinaryOpcode writes into it and it returns directly */
349 return 0;
350}
351
352int websocket_context_read(websocket_context* encodingContext, BIO* bio, BYTE* pBuffer, size_t size)
353{
354 int status = 0;
355 size_t effectiveDataLen = 0;
356
357 WINPR_ASSERT(bio);
358 WINPR_ASSERT(pBuffer);
359 WINPR_ASSERT(encodingContext);
360
361 while (TRUE)
362 {
363 switch (encodingContext->state)
364 {
365 case WebsocketStateOpcodeAndFin:
366 {
367 BYTE buffer[1] = WINPR_C_ARRAY_INIT;
368
369 ERR_clear_error();
370 status = BIO_read(bio, (char*)buffer, sizeof(buffer));
371 if (status <= 0)
372 return (effectiveDataLen > 0 ? WINPR_ASSERTING_INT_CAST(int, effectiveDataLen)
373 : status);
374
375 encodingContext->opcode = buffer[0];
376 if (((encodingContext->opcode & 0xf) != WebsocketContinuationOpcode) &&
377 (encodingContext->opcode & 0xf) < 0x08)
378 encodingContext->fragmentOriginalOpcode = encodingContext->opcode;
379 encodingContext->state = WebsocketStateLengthAndMasking;
380 }
381 break;
382 case WebsocketStateLengthAndMasking:
383 {
384 BYTE buffer[1] = WINPR_C_ARRAY_INIT;
385
386 ERR_clear_error();
387 status = BIO_read(bio, (char*)buffer, sizeof(buffer));
388 if (status <= 0)
389 return (effectiveDataLen > 0 ? WINPR_ASSERTING_INT_CAST(int, effectiveDataLen)
390 : status);
391
392 encodingContext->masking = ((buffer[0] & WEBSOCKET_MASK_BIT) == WEBSOCKET_MASK_BIT);
393 encodingContext->lengthAndMaskPosition = 0;
394 encodingContext->payloadLength = 0;
395 const BYTE len = buffer[0] & 0x7f;
396 if (len < 126)
397 {
398 encodingContext->payloadLength = len;
399 encodingContext->state = (encodingContext->masking ? WebSocketStateMaskingKey
400 : WebSocketStatePayload);
401 }
402 else if (len == 126)
403 encodingContext->state = WebsocketStateShortLength;
404 else
405 encodingContext->state = WebsocketStateLongLength;
406 }
407 break;
408 case WebsocketStateShortLength:
409 case WebsocketStateLongLength:
410 {
411 BYTE buffer[1] = WINPR_C_ARRAY_INIT;
412 const BYTE lenLength =
413 (encodingContext->state == WebsocketStateShortLength ? 2 : 8);
414 while (encodingContext->lengthAndMaskPosition < lenLength)
415 {
416 ERR_clear_error();
417 status = BIO_read(bio, (char*)buffer, sizeof(buffer));
418 if (status <= 0)
419 return (effectiveDataLen > 0
420 ? WINPR_ASSERTING_INT_CAST(int, effectiveDataLen)
421 : status);
422 if (status > UINT8_MAX)
423 return -1;
424 encodingContext->payloadLength =
425 (encodingContext->payloadLength) << 8 | buffer[0];
426 encodingContext->lengthAndMaskPosition +=
427 WINPR_ASSERTING_INT_CAST(BYTE, status);
428 }
429 encodingContext->state =
430 (encodingContext->masking ? WebSocketStateMaskingKey : WebSocketStatePayload);
431 }
432 break;
433 case WebSocketStateMaskingKey:
434 {
435 WLog_WARN(
436 TAG, "Websocket Server sends data with masking key. This is against RFC 6455.");
437 return -1;
438 }
439 case WebSocketStatePayload:
440 {
441 status = websocket_handle_payload(bio, pBuffer, size, encodingContext);
442 if (status < 0)
443 return (effectiveDataLen > 0 ? WINPR_ASSERTING_INT_CAST(int, effectiveDataLen)
444 : status);
445
446 effectiveDataLen += WINPR_ASSERTING_INT_CAST(size_t, status);
447
448 if (WINPR_ASSERTING_INT_CAST(size_t, status) >= size)
449 return WINPR_ASSERTING_INT_CAST(int, effectiveDataLen);
450 pBuffer += status;
451 size -= WINPR_ASSERTING_INT_CAST(size_t, status);
452 }
453 break;
454 default:
455 break;
456 }
457 }
458 /* should be unreachable */
459}
460
461websocket_context* websocket_context_new(void)
462{
463 websocket_context* context = calloc(1, sizeof(websocket_context));
464 if (!context)
465 goto fail;
466
467 context->responseStreamBuffer = Stream_New(nullptr, 1024);
468 if (!context->responseStreamBuffer)
469 goto fail;
470
471 if (!websocket_context_reset(context))
472 goto fail;
473
474 return context;
475fail:
476 websocket_context_free(context);
477 return nullptr;
478}
479
480void websocket_context_free(websocket_context* context)
481{
482 if (!context)
483 return;
484
485 Stream_Free(context->responseStreamBuffer, TRUE);
486 free(context);
487}
488
489BOOL websocket_context_reset(websocket_context* context)
490{
491 WINPR_ASSERT(context);
492
493 context->state = WebsocketStateOpcodeAndFin;
494 return Stream_SetPosition(context->responseStreamBuffer, 0);
495}