CCF
Loading...
Searching...
No Matches
tls_session.h
Go to the documentation of this file.
1// Copyright (c) Microsoft Corporation. All rights reserved.
2// Licensed under the Apache 2.0 License.
3#pragma once
4
5#include "ccf/ds/logger.h"
6#include "ds/messaging.h"
7#include "ds/ring_buffer.h"
9#include "enclave/session.h"
10#include "tcp/msg_types.h"
11#include "tls/context.h"
12#include "tls/tls.h"
13
14#include <exception>
15
16namespace ccf
17{
27
28 class TLSSession : public std::enable_shared_from_this<TLSSession>
29 {
30 public:
31 using HandshakeErrorCB = std::function<void(std::string&&)>;
32
33 protected:
37
38 private:
39 std::vector<uint8_t> pending_write;
40 std::vector<uint8_t> pending_read;
41 // Decrypted data
42 std::vector<uint8_t> read_buffer;
43
44 std::unique_ptr<tls::Context> ctx;
45 SessionStatus status;
46
47 HandshakeErrorCB handshake_error_cb;
48
49 bool can_send()
50 {
51 // Closing endpoint should still be able to respond to clients (e.g. to
52 // report errors)
53 return status == ready || status == closing;
54 }
55
56 bool can_recv()
57 {
58 return status == ready || status == handshake;
59 }
60
// Thread-messaging task payload carrying a byte buffer plus the owning
// session. send_raw() fills it and queues send_raw_cb, which passes the bytes
// to send_raw_thread().
61 struct SendRecvMsg
62 {
// Bytes to process, copied from the caller's buffer.
63 std::vector<uint8_t> data;
// Owning reference: keeps the session alive until the queued task has run.
64 std::shared_ptr<TLSSession> self;
65 };
66
// Thread-messaging task payload that only needs the target session.
// close() queues it with close_cb, which calls close_thread().
67 struct EmptyMsg
68 {
// Owning reference: keeps the session alive until the queued task has run.
69 std::shared_ptr<TLSSession> self;
70 };
71
72 public:
74 int64_t session_id_,
75 ringbuffer::AbstractWriterFactory& writer_factory_,
76 std::unique_ptr<tls::Context> ctx_) :
77 to_host(writer_factory_.create_writer_to_outside()),
78 session_id(session_id_),
79 ctx(std::move(ctx_)),
80 status(handshake)
81 {
// Every session starts in the `handshake` state (see the initialiser list).
// `this` is handed to the TLS context together with the send/receive BIO
// hooks. The hooks recover the session through BIO_get_callback_arg(), so
// `this` is presumably installed as the BIO callback argument; confirm in
// tls::Context::set_bio.
85 ctx->set_bio(this, send_callback_openssl, recv_callback_openssl);
86 }
87
// Tells the host, through the to_host ringbuffer writer, that this session's
// connection is gone (tcp_closed for session_id).
// NOTE(review): destructors are implicitly noexcept, so if
// RINGBUFFER_WRITE_MESSAGE can throw (e.g. when the message cannot be
// written) the process terminates. Confirm its failure mode.
88 virtual ~TLSSession()
89 {
90 RINGBUFFER_WRITE_MESSAGE(::tcp::tcp_closed, to_host, session_id);
91 }
92
94 {
// Current lifecycle state: handshake, ready, closing, closed, authfail or
// error.
95 return status;
96 }
97
98 void on_handshake_error(std::string&& error_msg)
99 {
100 if (handshake_error_cb)
101 {
102 handshake_error_cb(std::move(error_msg));
103 }
104 else
105 {
106 LOG_TRACE_FMT("{}", error_msg);
107 }
108 }
109
111 {
// Replaces any previously registered callback. on_handshake_error() invokes
// it instead of trace-logging when do_handshake() reports a failure.
112 handshake_error_cb = std::move(cb);
113 }
114
115 std::string hostname()
116 {
117 if (status != ready)
118 {
119 return {};
120 }
121
122 return ctx->host();
123 }
124
// Returns the peer certificate bytes as reported by the TLS context.
// NOTE(review): unlike hostname(), this does not check that the handshake
// has completed. What ctx->peer_cert() yields before `ready` depends on
// tls::Context; confirm.
125 std::vector<uint8_t> peer_cert()
126 {
127 return ctx->peer_cert();
128 }
129
130 // Returns count N of bytes read, which will be the first N bytes of data,
131 // up to a maximum of size. If exact is true, will only return either size
132 // or 0 (when size bytes are not currently available). data may be accessed
133 // beyond N during operation, up to size, but only the first N should be
134 // used by caller.
// Before reading, the handshake is advanced (without blocking) and pending
// writes are flushed. Bytes stashed in read_buffer by earlier exact-mode
// calls are served first. When the context reports the connection closed,
// the session is stopped as `closed`. Other negative results that reach the
// end of the switch stop it as `error` and return 0.
135 size_t read(uint8_t* data, size_t size, bool exact = false)
136 {
137 // This will return empty if the connection isn't
138 // ready, but it will not block on the handshake.
139 do_handshake();
140
141 if (status != ready)
142 {
143 LOG_TRACE_FMT("Not ready to read {} bytes", size);
144 return 0;
145 }
146
147 LOG_TRACE_FMT("Requesting up to {} bytes", size);
148
149 // Send pending writes.
150 flush();
151
152 size_t offset = 0;
153
154 if (read_buffer.size() > 0)
155 {
157 "Have existing read_buffer of size: {}", read_buffer.size());
158 offset = std::min(size, read_buffer.size());
159 ::memcpy(data, read_buffer.data(), offset);
160
161 if (offset < read_buffer.size())
162 read_buffer.erase(read_buffer.begin(), read_buffer.begin() + offset);
163 else
164 read_buffer.clear();
165
166 if (offset == size)
167 return size;
168
169 // NB: If we continue past here, read_buffer is empty
170 }
171
172 auto r = ctx->read(data + offset, size - offset);
173 LOG_TRACE_FMT("ctx->read returned: {}", r);
174
175 switch (r)
176 {
177 case 0:
179 {
181 "TLS {} close on read: {}", session_id, ::tls::error_string(r));
182
183 stop(closed);
184
185 if (!exact)
186 {
187 // Connection closed, but may still have some useful data from the
188 // previous read_buffer
189 return offset;
190 }
191
192 return 0;
193 }
194
197 {
198 if (!exact)
199 {
200 return offset;
201 }
202
203 // May have read something but not enough - copy it into read_buffer
204 // for next call
205 read_buffer.insert(read_buffer.end(), data, data + offset);
206 return 0;
207 }
208
209 default:
210 {
211 }
212 }
213
214 if (r < 0)
215 {
217 "TLS {} error on read: {}", session_id, ::tls::error_string(r));
218 stop(error);
219 return 0;
220 }
221
222 auto total = r + offset;
223
224 // We read _some_ data but not enough, and didn't get
225 // TLS_ERR_WANT_READ. Probably hit an internal size limit - try
226 // again
227 if (exact && (total < size))
228 {
230 "Asked for exactly {}, received {}, retrying", size, total);
// Stash everything read so far; the recursive call serves it back from
// read_buffer before asking the context for the remainder.
231 read_buffer.insert(read_buffer.end(), data, data + total);
232 return read(data, size, exact);
233 }
234
235 return total;
236 }
237
// Queues bytes received from the host in pending_read, where handle_recv()
// hands them to the TLS layer, then advances the handshake if it is still in
// progress. Bytes are dropped unless can_recv(). Throws std::runtime_error
// when called from the wrong thread.
238 void recv_buffered(const uint8_t* data, size_t size)
239 {
241 {
242 throw std::runtime_error("Called recv_buffered from incorrect thread");
243 }
244
245 if (can_recv())
246 {
247 pending_read.insert(pending_read.end(), data, data + size);
248 }
249
250 do_handshake();
251 }
252
// Marks the session as `closing` (it can still send, but no longer accepts
// received bytes), then runs close_thread(). The call happens either on
// execution_thread via a queued EmptyMsg / close_cb task, or inline.
// NOTE(review): status is overwritten unconditionally, even if the session
// already reached closed/authfail/error. close_thread() then treats it as
// `closing`, and stop() will notify the host a second time. Confirm whether
// close() should be a no-op in terminal states.
253 void close()
254 {
255 status = closing;
257 {
258 auto msg = std::make_unique<::threading::Tmsg<EmptyMsg>>(&close_cb);
259 msg->data.self = this->shared_from_this();
260
262 execution_thread, std::move(msg));
263 }
264 else
265 {
266 // Close inline immediately
267 close_thread();
268 }
269 }
270
271 static void close_cb(std::unique_ptr<::threading::Tmsg<EmptyMsg>> msg)
272 {
273 msg->data.self->close_thread();
274 }
275
// Performs the shutdown; throws std::runtime_error when called from the
// wrong thread.
//  - handshake: stopped as `closed` without asking the context to close.
//  - ready/closing: ctx->close() is called. The session is then stopped as
//    `closed`, or as `error` when ctx->close() returns a code not listed in
//    the inner switch.
//  - any other state: nothing happens.
// NOTE(review): close() sets status to `closing` before this runs, so the
// `handshake` branch is only reached when close_thread() is called directly.
276 virtual void close_thread()
277 {
279 {
280 throw std::runtime_error("Called close_thread from incorrect thread");
281 }
282
283 switch (status)
284 {
285 case handshake:
286 {
287 LOG_TRACE_FMT("TLS {} closed during handshake", session_id);
288 stop(closed);
289 break;
290 }
291
292 case ready:
293 case closing:
294 {
295 int r = ctx->close();
296
297 switch (r)
298 {
301 {
302 LOG_TRACE_FMT("TLS {} has pending data ({})", session_id, r);
303 // FALLTHROUGH
304 }
305 case 0:
306 {
307 LOG_TRACE_FMT("TLS {} closed ({})", session_id, r);
308 stop(closed);
309 break;
310 }
311
312 default:
313 {
315 "TLS {} error on_close: {}",
318 stop(error);
319 break;
320 }
321 }
322 break;
323 }
324
325 default:
326 {
327 }
328 }
329 }
330
// Sends application bytes over the session. Either the bytes are copied into
// a SendRecvMsg and queued for execution_thread (handled by send_raw_cb), or
// send_raw_thread() is called inline.
331 void send_raw(const uint8_t* data, size_t size)
332 {
334 {
335 auto msg =
336 std::make_unique<::threading::Tmsg<SendRecvMsg>>(&send_raw_cb);
337 msg->data.self = this->shared_from_this();
338 msg->data.data = std::vector<uint8_t>(data, data + size);
339
341 execution_thread, std::move(msg));
342 }
343 else
344 {
345 // Send inline immediately
346 send_raw_thread(data, size);
347 }
348 }
349
350 private:
351 static void send_raw_cb(std::unique_ptr<::threading::Tmsg<SendRecvMsg>> msg)
352 {
353 msg->data.self->send_raw_thread(
354 msg->data.data.data(), msg->data.data.size());
355 }
356
357 void send_raw_thread(const uint8_t* data, size_t size)
358 {
360 {
361 throw std::runtime_error(
362 "Called send_raw_thread from incorrect thread");
363 }
364 // Writes as much of the data as possible. If the data cannot all
365 // be written now, the remainder stays in pending_write and is retried
366 // by the next flush() (e.g. from read() or a later send).
367 do_handshake();
368
// Still handshaking: queue the bytes. A later flush() sends them once the
// handshake has completed.
369 if (status == handshake)
370 {
371 pending_write.insert(pending_write.end(), data, data + size);
372 return;
373 }
374
// Closed, failed or errored sessions drop the data silently.
375 if (!can_send())
376 {
377 return;
378 }
379
380 pending_write.insert(pending_write.end(), data, data + size);
381
382 flush();
383 }
384
// Appends bytes to pending_write without sending them; they go out on the
// next flush(). Throws std::runtime_error when called from the wrong thread.
385 void send_buffered(const std::vector<uint8_t>& data)
386 {
388 {
389 throw std::runtime_error("Called send_buffered from incorrect thread");
390 }
391
392 pending_write.insert(pending_write.end(), data.begin(), data.end());
393 }
394
// Sends as much of pending_write as the TLS context accepts. It first
// advances the handshake and does nothing unless can_send(). Bytes that were
// written are erased from the front of pending_write. A 0 result from
// write_some() ends the loop and keeps the remainder for a later call.
// Throws std::runtime_error when called from the wrong thread.
395 void flush()
396 {
398 {
399 throw std::runtime_error("Called flush from incorrect thread");
400 }
401
402 do_handshake();
403
404 if (!can_send())
405 {
406 return;
407 }
408
409 while (pending_write.size() > 0)
410 {
411 auto r = write_some(pending_write);
412
413 if (r > 0)
414 {
415 pending_write.erase(pending_write.begin(), pending_write.begin() + r);
416 }
417 else if (r == 0)
418 {
419 break;
420 }
421 else
422 {
423 LOG_TRACE_FMT("TLS session {} error on flush: {}", session_id, -r);
424 stop(error);
// NOTE(review): nothing exits the loop here and pending_write is unchanged.
// If write_some() keeps returning an error, this spins forever. A `break;`
// (or `return;`) after stop(error) looks intended; confirm and fix.
425 }
426 }
427 }
428
// Advances the TLS handshake by one ctx->handshake() call and does nothing
// unless status is `handshake`. Outcomes:
//  - rc == 0: the session becomes `ready`.
//  - codes handled by a bare `break`: status stays `handshake`, so the
//    handshake is retried on the next call.
//  - verification / invalid-certificate failures: reported through
//    on_handshake_error(), then the session is stopped as `authfail`.
//  - a close: stopped as `closed`.
//  - anything else: reported through on_handshake_error(), then stopped as
//    `error`.
429 void do_handshake()
430 {
431 // This should be called when additional data is written to the
432 // input buffer, until the handshake is complete.
433 if (status != handshake)
434 {
435 return;
436 }
437
438 auto rc = ctx->handshake();
439
440 switch (rc)
441 {
442 case 0:
443 {
444 status = ready;
445 break;
446 }
447
450 break;
451
453 {
454 on_handshake_error(fmt::format(
455 "TLS {} verify error on handshake: {}",
458 stop(authfail);
459 break;
460 }
461
463 {
465 "TLS {} closed on handshake: {}",
468 stop(closed);
469 break;
470 }
471
473 {
474 auto err = ctx->get_verify_error();
475 on_handshake_error(fmt::format(
476 "TLS {} invalid cert on handshake: {} [{}]",
478 err,
480 stop(authfail);
481 return;
482 }
483
484 default:
485 {
486 on_handshake_error(fmt::format(
487 "TLS {} error on handshake: {}",
490 stop(error);
491 break;
492 }
493 }
494 }
495
// Makes one ctx->write() attempt with the whole buffer. The result codes
// listed in the first case group are mapped to 0, which flush() treats as
// "stop and retry later". Any other value (bytes written, or a negative
// error code) is returned unchanged.
496 int write_some(const std::vector<uint8_t>& data)
497 {
498 auto r = ctx->write(data.data(), data.size());
499
500 switch (r)
501 {
504 return 0;
505
506 default:
507 return r;
508 }
509 }
510
511 void stop(SessionStatus status_)
512 {
513 switch (status)
514 {
515 case closed:
516 case authfail:
517 case error:
518 return;
519
520 default:
521 {
522 }
523 }
524
525 status = status_;
526
527 switch (status)
528 {
529 case closing:
530 case closed:
531 {
533 ::tcp::tcp_stop,
534 to_host,
536 std::string("Session closed"));
537 break;
538 }
539
540 case authfail:
541 {
543 ::tcp::tcp_stop,
544 to_host,
546 std::string("Authentication failed"));
547 }
548 case error:
549 {
551 ::tcp::tcp_stop, to_host, session_id, std::string("Error"));
552 break;
553 }
554
555 default:
556 {
557 }
558 }
559 }
560
// BIO send hook (reached via send_callback): forwards outgoing TLS bytes to
// the host as a single tcp_outbound message. Returns len when the message was
// written, or TLS_WRITING when the ringbuffer write did not succeed.
// NOTE(review): `(int)len` truncates if len ever exceeds INT_MAX.
561 int handle_send(const uint8_t* buf, size_t len)
562 {
563 // Either write all of the data or none of it.
564 auto wrote = RINGBUFFER_TRY_WRITE_MESSAGE(
565 ::tcp::tcp_outbound,
566 to_host,
568 serializer::ByteRange{buf, len});
569
570 if (!wrote)
571 return TLS_WRITING;
572
573 return (int)len;
574 }
575
// BIO receive hook (reached via recv_callback): copies up to `len` bytes from
// the front of pending_read into `buf` and removes them. Returns the number
// of bytes copied, or TLS_READING when nothing is buffered. Throws
// std::runtime_error when called from the wrong thread.
576 int handle_recv(uint8_t* buf, size_t len)
577 {
579 {
580 throw std::runtime_error("Called handle_recv from incorrect thread");
581 }
582 if (pending_read.size() > 0)
583 {
584 // Serve from pending_read, which recv_buffered() fills with bytes
585 // received from the host; any excess stays queued for the next call.
586 size_t rd = std::min(len, pending_read.size());
587 ::memcpy(buf, pending_read.data(), rd);
588
589 if (rd >= pending_read.size())
590 {
591 pending_read.clear();
592 }
593 else
594 {
595 pending_read.erase(pending_read.begin(), pending_read.begin() + rd);
596 }
597
598 return (int)rd;
599 }
600
601 return TLS_READING;
602 }
603
604 static int send_callback(void* ctx, const unsigned char* buf, size_t len)
605 {
606 return reinterpret_cast<TLSSession*>(ctx)->handle_send(buf, len);
607 }
608
609 static int recv_callback(void* ctx, unsigned char* buf, size_t len)
610 {
611 return reinterpret_cast<TLSSession*>(ctx)->handle_recv(buf, len);
612 }
613
614 // These callbacks below are complex, using the callbacks above and
615 // manipulating OpenSSL's BIO objects accordingly. This is just so we can
616 // emulate what MbedTLS used to do.
617 // Now that we have removed it from the code, we can move the callbacks
618 // above to handle BIOs directly and hopefully remove the complexity below.
619 // This work will be carried out in #3429.
620 static long send_callback_openssl(
621 BIO* b,
622 int oper,
623 const char* argp,
624 size_t len,
625 int argi,
626 long argl,
627 int ret,
628 size_t* processed)
629 {
630 // Unused arguments
631 (void)argi;
632 (void)argl;
633 (void)argp;
634
635 if (ret && len > 0 && oper == (BIO_CB_WRITE | BIO_CB_RETURN))
636 {
637 // Flush BIO so the "pipe doesn't clog", but we don't use the
638 // data here, because 'argp' already has it.
639 BIO_flush(b);
640 size_t pending = BIO_pending(b);
641 if (pending)
642 BIO_reset(b);
643
644 // Pipe object
645 void* ctx = (BIO_get_callback_arg(b));
646 int put = send_callback(ctx, (const uint8_t*)argp, len);
647
648 // WANTS_WRITE
649 if (put == TLS_WRITING)
650 {
651 LOG_TRACE_FMT("TLS Session::send_cb() : WANTS_WRITE");
652 *processed = 0;
653 return -1;
654 }
655 else
656 {
657 LOG_TRACE_FMT("TLS Session::send_cb() : Put {} bytes", put);
658 }
659
660 // Update the number of bytes to external users
661 *processed = put;
662 }
663
664 // Unless we detected an error, the return value is always the same as the
665 // original operation.
666 return ret;
667 }
668
// BIO read hook installed through ctx->set_bio() in the constructor. After a
// read on `b` completes (BIO_CB_READ | BIO_CB_RETURN), it pulls up to `len`
// bytes of host data from pending_read into `argp` via recv_callback(). How
// that result is reported back is described inline below.
// NOTE(review): `argp` is declared const char* but is written through a cast
// to uint8_t*. This relies on OpenSSL passing its (writable) read destination
// buffer here.
669 static long recv_callback_openssl(
670 BIO* b,
671 int oper,
672 const char* argp,
673 size_t len,
674 int argi,
675 long argl,
676 int ret,
677 size_t* processed)
678 {
679 // Unused arguments
680 (void)argi;
681 (void)argl;
682
683 if (ret == 1 && oper == (BIO_CB_CTRL | BIO_CB_RETURN))
684 {
685 // This callback may be fired at the end of large batches of TLS frames
686 // on OpenSSL 3.x. Note that processed == nullptr in this case, hence
687 // the early exit.
688 return 0;
689 }
690
691 if (ret && (oper == (BIO_CB_READ | BIO_CB_RETURN)))
692 {
693 // Pipe object
694 void* ctx = (BIO_get_callback_arg(b));
695 int got = recv_callback(ctx, (uint8_t*)argp, len);
696
697 // WANTS_READ
698 if (got == TLS_READING)
699 {
700 LOG_TRACE_FMT("TLS Session::recv_cb() : WANTS_READ");
701 *processed = 0;
702 return -1;
703 }
704 else
705 {
707 "TLS Session::recv_cb() : Got {} bytes of {}", got, len);
708 }
709
710 // Short read: report the `got` bytes that handle_recv() copied into argp
// and return 1 (success) without writing anything to the BIO.
711 if ((size_t)got < len)
712 {
713 *processed = got;
714 return 1;
715 }
716
717 // Write to the actual BIO so SSL can use it
718 BIO_write_ex(b, argp, got, processed);
719
720 // The buffer should be enough, we can't return WANT_WRITE here
721 if ((size_t)got != *processed)
722 {
723 LOG_TRACE_FMT("TLS Session::recv_cb() : BIO error");
724 *processed = got;
725 return -1;
726 }
727
728 // If original return was -1 because it didn't find anything to read,
729 // return 1 to say we actually read something. This is common when the
730 // buffer is empty and needs an external read, so let's not log this.
731 if (got > 0 && ret < 0)
732 {
733 return 1;
734 }
735 }
736
737 // Unless we detected an error, the return value is always the same as the
738 // original operation.
739 return ret;
740 }
741 };
742}
Definition tls_session.h:29
SessionStatus get_status() const
Definition tls_session.h:93
void send_raw(const uint8_t *data, size_t size)
Definition tls_session.h:331
std::string hostname()
Definition tls_session.h:115
void recv_buffered(const uint8_t *data, size_t size)
Definition tls_session.h:238
virtual void close_thread()
Definition tls_session.h:276
std::function< void(std::string &&)> HandshakeErrorCB
Definition tls_session.h:31
size_t read(uint8_t *data, size_t size, bool exact=false)
Definition tls_session.h:135
TLSSession(int64_t session_id_, ringbuffer::AbstractWriterFactory &writer_factory_, std::unique_ptr< tls::Context > ctx_)
Definition tls_session.h:73
virtual ~TLSSession()
Definition tls_session.h:88
void on_handshake_error(std::string &&error_msg)
Definition tls_session.h:98
static void close_cb(std::unique_ptr<::threading::Tmsg< EmptyMsg > > msg)
Definition tls_session.h:271
void close()
Definition tls_session.h:253
std::vector< uint8_t > peer_cert()
Definition tls_session.h:125
ringbuffer::WriterPtr to_host
Definition tls_session.h:34
size_t execution_thread
Definition tls_session.h:36
void set_handshake_error_cb(HandshakeErrorCB &&cb)
Definition tls_session.h:110
::tcp::ConnID session_id
Definition tls_session.h:35
Definition ring_buffer_types.h:153
static ThreadMessaging & instance()
Definition thread_messaging.h:278
void add_task(uint16_t tid, std::unique_ptr< Tmsg< Payload > > msg)
Definition thread_messaging.h:312
uint16_t get_execution_thread(uint32_t i)
Definition thread_messaging.h:365
#define LOG_TRACE_FMT
Definition logger.h:378
uint16_t get_current_thread_id()
Definition thread_local.cpp:9
Definition app_interface.h:15
SessionStatus
Definition tls_session.h:19
@ closed
Definition tls_session.h:23
@ authfail
Definition tls_session.h:24
@ error
Definition tls_session.h:25
@ ready
Definition tls_session.h:21
@ closing
Definition tls_session.h:22
@ handshake
Definition tls_session.h:20
std::shared_ptr< AbstractWriter > WriterPtr
Definition ring_buffer_types.h:150
STL namespace.
int64_t ConnID
Definition msg_types.h:9
std::string error_string(int ec)
Definition tls.h:32
#define RINGBUFFER_TRY_WRITE_MESSAGE(MSG,...)
Definition ring_buffer_types.h:258
#define RINGBUFFER_WRITE_MESSAGE(MSG,...)
Definition ring_buffer_types.h:255
Definition serializer.h:27
Definition thread_messaging.h:27
#define TLS_ERR_X509_VERIFY
Definition tls.h:24
#define TLS_READING
Definition tls.h:14
#define TLS_ERR_WANT_WRITE
Definition tls.h:17
#define TLS_ERR_WANT_READ
Definition tls.h:16
#define TLS_WRITING
Definition tls.h:15
#define TLS_ERR_CONN_CLOSE_NOTIFY
Definition tls.h:18
#define TLS_ERR_NEED_CERT
Definition tls.h:19