CCF
thread_messaging.h — annotated source listing (see the file's documentation page for the API reference).
1// Copyright (c) Microsoft Corporation. All rights reserved.
2// Licensed under the Apache 2.0 License.
3#pragma once
4
#include "ccf/ccf_assert.h"
#include "ccf/ds/logger.h"

#include <atomic>
#include <chrono>
#include <cstddef>
#include <cstdint>
#include <map>
#include <memory>
#include <stdexcept>
#include <utility>
#include <vector>
12
13namespace threading
14{
/// Type-erased unit of work that can be linked into a TaskQueue.
///
/// When the message is executed, `cb` is called and receives ownership of
/// the message as a std::unique_ptr.
struct ThreadMsg
{
  /// Invoked with ownership of this message when it is run.
  void (*cb)(std::unique_ptr<ThreadMsg>);

  /// Intrusive link used while the message sits in a TaskQueue.
  std::atomic<ThreadMsg*> next{nullptr};

  ThreadMsg(void (*callback)(std::unique_ptr<ThreadMsg>)) : cb{callback} {}

  // Virtual so that a derived message is destroyed correctly through a
  // ThreadMsg pointer (TaskQueue deletes pending messages via the base).
  virtual ~ThreadMsg() = default;
};
24
25 template <typename Payload>
26 struct alignas(16) Tmsg : public ThreadMsg
27 {
28 Payload data;
29
30 template <typename... Args>
31 Tmsg(void (*_cb)(std::unique_ptr<Tmsg<Payload>>), Args&&... args) :
32 ThreadMsg(reinterpret_cast<void (*)(std::unique_ptr<ThreadMsg>)>(_cb)),
33 data(std::forward<Args>(args)...)
34 {}
35
36 void reset_cb(void (*_cb)(std::unique_ptr<Tmsg<Payload>>))
37 {
38 cb = reinterpret_cast<void (*)(std::unique_ptr<ThreadMsg>)>(_cb);
39 }
40
41 virtual ~Tmsg() = default;
42 };
43
44 class ThreadMessaging;
45
47 {
48 std::atomic<ThreadMsg*> item_head = nullptr;
49 ThreadMsg* local_msg = nullptr;
50
51 public:
52 TaskQueue() = default;
53
55 {
56 if (local_msg == nullptr && item_head != nullptr)
57 {
58 local_msg = item_head.exchange(nullptr);
59 reverse_local_messages();
60 }
61
62 if (local_msg == nullptr)
63 {
64 return false;
65 }
66
67 ThreadMsg* current = local_msg;
68 local_msg = local_msg->next;
69
70 current->cb(std::unique_ptr<ThreadMsg>(current));
71 return true;
72 }
73
74 void add_task(ThreadMsg* item)
75 {
76 ThreadMsg* tmp_head;
77 do
78 {
79 tmp_head = item_head.load();
80 item->next = tmp_head;
81 } while (!item_head.compare_exchange_strong(tmp_head, item));
82 }
83
85 {
87 TimerEntry(std::chrono::milliseconds time_offset_, uint64_t counter_) :
88 time_offset(time_offset_),
89 counter(counter_)
90 {}
91
92 std::chrono::milliseconds time_offset;
93 uint64_t counter;
94 };
95
97 {
98 bool operator()(const TimerEntry& lhs, const TimerEntry& rhs) const
99 {
100 if (lhs.time_offset != rhs.time_offset)
101 {
102 return lhs.time_offset < rhs.time_offset;
103 }
104
105 return lhs.counter < rhs.counter;
106 }
107 };
108
110 std::unique_ptr<ThreadMsg> item, std::chrono::milliseconds ms)
111 {
112 TimerEntry entry = {time_offset + ms, time_entry_counter++};
113 if (timer_map.empty() || entry.time_offset <= next_time_offset)
114 {
115 next_time_offset = entry.time_offset;
116 }
117
118 timer_map.emplace(entry, std::move(item));
119 return entry;
120 }
121
123 {
124 auto num_erased = timer_map.erase(timer_entry);
125 CCF_ASSERT(num_erased <= 1, "Too many items erased");
126 if (!timer_map.empty() && timer_entry.time_offset <= next_time_offset)
127 {
128 next_time_offset = timer_map.begin()->first.time_offset;
129 }
130 return num_erased != 0;
131 }
132
133 void tick(std::chrono::milliseconds elapsed)
134 {
135 time_offset += elapsed;
136
137 bool updated = false;
138
139 while (!timer_map.empty() && next_time_offset <= time_offset &&
140 timer_map.begin()->first.time_offset <= time_offset)
141 {
142 updated = true;
143 auto it = timer_map.begin();
144
145 auto& cb = it->second->cb;
146 auto msg = std::move(it->second);
147 timer_map.erase(it);
148 cb(std::move(msg));
149 }
150
151 if (updated && !timer_map.empty())
152 {
153 next_time_offset = timer_map.begin()->first.time_offset;
154 }
155 }
156
157 std::chrono::milliseconds get_current_time_offset()
158 {
159 return time_offset;
160 }
161
162 private:
163 std::chrono::milliseconds time_offset = std::chrono::milliseconds(0);
164 uint64_t time_entry_counter = 0;
165 std::map<TimerEntry, std::unique_ptr<ThreadMsg>, TimerEntryCompare>
166 timer_map;
167 std::chrono::milliseconds next_time_offset;
168
169 void reverse_local_messages()
170 {
171 if (local_msg == nullptr)
172 return;
173
174 ThreadMsg *prev = nullptr, *current = nullptr, *next = nullptr;
175 current = local_msg;
176 while (current != nullptr)
177 {
178 next = current->next;
179 current->next = prev;
180 prev = current;
181 current = next;
182 }
183 // now let the head point at the last node (prev)
184 local_msg = prev;
185 }
186
187 void drop()
188 {
189 while (true)
190 {
191 if (local_msg == nullptr && item_head != nullptr)
192 {
193 local_msg = item_head.exchange(nullptr);
194 reverse_local_messages();
195 }
196
197 if (local_msg == nullptr)
198 {
199 break;
200 }
201
202 ThreadMsg* current = local_msg;
203 local_msg = local_msg->next;
204 delete current;
205 }
206 }
207
208 friend ThreadMessaging;
209 };
210
212 {
213 std::atomic<bool> finished;
214 std::vector<TaskQueue> tasks; // Fixed-size at construction
215
216 // Drop all pending tasks, this is only ever to be used
217 // on shutdown, to avoid leaks, and after all thread but
218 // the main one have been shut down.
219 void drop_tasks()
220 {
221 for (auto& t : tasks)
222 {
223 t.drop();
224 }
225 }
226
227 inline TaskQueue& get_tasks(uint16_t task_id)
228 {
229 if (task_id >= tasks.size())
230 {
231 throw std::runtime_error(fmt::format(
232 "Attempting to access task_id >= task_count, task_id:{}, "
233 "task_count:{}",
234 task_id,
235 tasks.size()));
236 }
237 return tasks[task_id];
238 }
239
240 static std::unique_ptr<ThreadMessaging> singleton;
241
242 public:
243 static constexpr uint16_t max_num_threads = 24;
244
245 ThreadMessaging(uint16_t num_task_queues) :
246 finished(false),
247 tasks(num_task_queues)
248 {
249 if (num_task_queues > max_num_threads)
250 {
251 throw std::logic_error(fmt::format(
252 "ThreadMessaging constructed with too many tasks: {} > {}",
253 num_task_queues,
255 }
256 }
257
259 {
260 drop_tasks();
261 }
262
263 static void init(uint16_t num_task_queues)
264 {
265 if (singleton != nullptr)
266 {
267 throw std::logic_error("Called init() multiple times");
268 }
269
270 singleton = std::make_unique<ThreadMessaging>(num_task_queues);
271 }
272
273 static void shutdown()
274 {
275 singleton.reset();
276 }
277
279 {
280 if (singleton == nullptr)
281 {
282 throw std::logic_error(
283 "Attempted to access global ThreadMessaging instance without first "
284 "calling init()");
285 }
286
287 return *singleton;
288 }
289
290 void set_finished(bool v = true)
291 {
292 finished.store(v);
293 }
294
295 void run()
296 {
298
299 while (!is_finished())
300 {
301 task.run_next_task();
302 }
303 }
304
305 bool run_one()
306 {
308 return task.run_next_task();
309 }
310
311 template <typename Payload>
312 void add_task(uint16_t tid, std::unique_ptr<Tmsg<Payload>> msg)
313 {
314 TaskQueue& task = get_tasks(tid);
315
316 task.add_task(reinterpret_cast<ThreadMsg*>(msg.release()));
317 }
318
319 template <typename Payload>
321 std::unique_ptr<Tmsg<Payload>> msg, std::chrono::milliseconds ms)
322 {
324 return task.add_task_after(std::move(msg), ms);
325 }
326
328 {
330 return task.cancel_timer_task(timer_entry);
331 }
332
333 std::chrono::milliseconds get_current_time_offset()
334 {
336 return task.get_current_time_offset();
337 }
338
339 struct TickMsg
340 {
341 TickMsg(std::chrono::milliseconds elapsed_, TaskQueue& task_) :
342 elapsed(elapsed_),
343 task(task_)
344 {}
345
346 std::chrono::milliseconds elapsed;
348 };
349
350 static void tick_cb(std::unique_ptr<Tmsg<TickMsg>> msg)
351 {
352 msg->data.task.tick(msg->data.elapsed);
353 }
354
355 void tick(std::chrono::milliseconds elapsed)
356 {
357 for (auto i = 0ul; i < tasks.size(); ++i)
358 {
359 auto& task = get_tasks(i);
360 auto msg = std::make_unique<Tmsg<TickMsg>>(&tick_cb, elapsed, task);
361 task.add_task(msg.release());
362 }
363 }
364
365 uint16_t get_execution_thread(uint32_t i)
366 {
367 uint16_t tid = ccf::threading::MAIN_THREAD_ID;
368 if (tasks.size() > 1)
369 {
370 // If we have multiple task queues, then we distinguish the main thread
371 // from the remaining workers; anything asking for an execution thread
372 // does _not_ go to the main thread's queue
373 tid = (i % (tasks.size() - 1));
374 ++tid;
375 }
376
377 return tid;
378 }
379
380 uint16_t thread_count() const
381 {
382 return tasks.size();
383 }
384
385 private:
386 bool is_finished()
387 {
388 return finished.load();
389 }
390 };
391};
#define CCF_ASSERT(expr, msg)
Definition ccf_assert.h:14
Definition thread_messaging.h:47
void tick(std::chrono::milliseconds elapsed)
Definition thread_messaging.h:133
bool run_next_task()
Definition thread_messaging.h:54
void add_task(ThreadMsg *item)
Definition thread_messaging.h:74
std::chrono::milliseconds get_current_time_offset()
Definition thread_messaging.h:157
TimerEntry add_task_after(std::unique_ptr< ThreadMsg > item, std::chrono::milliseconds ms)
Definition thread_messaging.h:109
bool cancel_timer_task(TimerEntry timer_entry)
Definition thread_messaging.h:122
Definition thread_messaging.h:212
ThreadMessaging(uint16_t num_task_queues)
Definition thread_messaging.h:245
void tick(std::chrono::milliseconds elapsed)
Definition thread_messaging.h:355
void run()
Definition thread_messaging.h:295
std::chrono::milliseconds get_current_time_offset()
Definition thread_messaging.h:333
static void shutdown()
Definition thread_messaging.h:273
uint16_t thread_count() const
Definition thread_messaging.h:380
static ThreadMessaging & instance()
Definition thread_messaging.h:278
bool cancel_timer_task(TaskQueue::TimerEntry timer_entry)
Definition thread_messaging.h:327
bool run_one()
Definition thread_messaging.h:305
TaskQueue::TimerEntry add_task_after(std::unique_ptr< Tmsg< Payload > > msg, std::chrono::milliseconds ms)
Definition thread_messaging.h:320
void add_task(uint16_t tid, std::unique_ptr< Tmsg< Payload > > msg)
Definition thread_messaging.h:312
static void init(uint16_t num_task_queues)
Definition thread_messaging.h:263
void set_finished(bool v=true)
Definition thread_messaging.h:290
~ThreadMessaging()
Definition thread_messaging.h:258
static void tick_cb(std::unique_ptr< Tmsg< TickMsg > > msg)
Definition thread_messaging.h:350
uint16_t get_execution_thread(uint32_t i)
Definition thread_messaging.h:365
static constexpr uint16_t max_num_threads
Definition thread_messaging.h:243
uint16_t get_current_thread_id()
Definition thread_local.cpp:9
STL namespace.
Definition thread_messaging.h:14
Definition thread_messaging.h:97
bool operator()(const TimerEntry &lhs, const TimerEntry &rhs) const
Definition thread_messaging.h:98
Definition thread_messaging.h:85
TimerEntry(std::chrono::milliseconds time_offset_, uint64_t counter_)
Definition thread_messaging.h:87
std::chrono::milliseconds time_offset
Definition thread_messaging.h:92
TimerEntry()
Definition thread_messaging.h:86
uint64_t counter
Definition thread_messaging.h:93
Definition thread_messaging.h:340
std::chrono::milliseconds elapsed
Definition thread_messaging.h:346
TaskQueue & task
Definition thread_messaging.h:347
TickMsg(std::chrono::milliseconds elapsed_, TaskQueue &task_)
Definition thread_messaging.h:341
Definition thread_messaging.h:16
void(* cb)(std::unique_ptr< ThreadMsg >)
Definition thread_messaging.h:17
std::atomic< ThreadMsg * > next
Definition thread_messaging.h:18
virtual ~ThreadMsg()=default
ThreadMsg(void(*_cb)(std::unique_ptr< ThreadMsg >))
Definition thread_messaging.h:20
Definition thread_messaging.h:27
Tmsg(void(*_cb)(std::unique_ptr< Tmsg< Payload > >), Args &&... args)
Definition thread_messaging.h:31
virtual ~Tmsg()=default
void reset_cb(void(*_cb)(std::unique_ptr< Tmsg< Payload > >))
Definition thread_messaging.h:36
Payload data
Definition thread_messaging.h:28