1  
//
1  
//
2  
// Copyright (c) 2025 Vinnie Falco (vinnie.falco@gmail.com)
2  
// Copyright (c) 2025 Vinnie Falco (vinnie.falco@gmail.com)
3  
//
3  
//
4  
// Distributed under the Boost Software License, Version 1.0. (See accompanying
4  
// Distributed under the Boost Software License, Version 1.0. (See accompanying
5  
// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
5  
// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
6  
//
6  
//
7  
// Official repository: https://github.com/cppalliance/capy
7  
// Official repository: https://github.com/cppalliance/capy
8  
//
8  
//
9  

9  

10  
#include "src/ex/detail/strand_queue.hpp"
10  
#include "src/ex/detail/strand_queue.hpp"
11  
#include <boost/capy/ex/detail/strand_service.hpp>
11  
#include <boost/capy/ex/detail/strand_service.hpp>
12  
#include <atomic>
12  
#include <atomic>
13  
#include <coroutine>
13  
#include <coroutine>
14  
#include <mutex>
14  
#include <mutex>
15  
#include <thread>
15  
#include <thread>
16  
#include <utility>
16  
#include <utility>
17  

17  

18  
namespace boost {
18  
namespace boost {
19  
namespace capy {
19  
namespace capy {
20  
namespace detail {
20  
namespace detail {
21  

21  

22  
//----------------------------------------------------------
22  
//----------------------------------------------------------
23  

23  

24  
/** Implementation state for a strand.
24  
/** Implementation state for a strand.
25  

25  

26  
    Each strand_impl provides serialization for coroutines
26  
    Each strand_impl provides serialization for coroutines
27  
    dispatched through strands that share it.
27  
    dispatched through strands that share it.
28  
*/
28  
*/
 
29 +
// Sentinel stored in cached_frame_ after shutdown to prevent
 
30 +
// in-flight invokers from repopulating a freed cache slot.
 
31 +
inline void* const kCacheClosed = reinterpret_cast<void*>(1);
 
32 +

29  
struct strand_impl
33  
struct strand_impl
30  
{
34  
{
31  
    std::mutex mutex_;
35  
    std::mutex mutex_;
32  
    strand_queue pending_;
36  
    strand_queue pending_;
33  
    bool locked_ = false;
37  
    bool locked_ = false;
34  
    std::atomic<std::thread::id> dispatch_thread_{};
38  
    std::atomic<std::thread::id> dispatch_thread_{};
35 -
    void* cached_frame_ = nullptr;
39 +
    std::atomic<void*> cached_frame_{nullptr};
36  
};
40  
};
37  

41  

38  
//----------------------------------------------------------
42  
//----------------------------------------------------------
39  

43  

40  
/** Invoker coroutine for strand dispatch.
44  
/** Invoker coroutine for strand dispatch.
41  

45  

42  
    Uses custom allocator to recycle frame - one allocation
46  
    Uses custom allocator to recycle frame - one allocation
43  
    per strand_impl lifetime, stored in trailer for recovery.
47  
    per strand_impl lifetime, stored in trailer for recovery.
44  
*/
48  
*/
45  
struct strand_invoker
49  
struct strand_invoker
46  
{
50  
{
47  
    struct promise_type
51  
    struct promise_type
48  
    {
52  
    {
49  
        void* operator new(std::size_t n, strand_impl& impl)
53  
        void* operator new(std::size_t n, strand_impl& impl)
50  
        {
54  
        {
51  
            constexpr auto A = alignof(strand_impl*);
55  
            constexpr auto A = alignof(strand_impl*);
52  
            std::size_t padded = (n + A - 1) & ~(A - 1);
56  
            std::size_t padded = (n + A - 1) & ~(A - 1);
53  
            std::size_t total = padded + sizeof(strand_impl*);
57  
            std::size_t total = padded + sizeof(strand_impl*);
54  

58  

55 -
            void* p = impl.cached_frame_
59 +
            void* p = impl.cached_frame_.exchange(
56 -
                ? std::exchange(impl.cached_frame_, nullptr)
60 +
                nullptr, std::memory_order_acquire);
57 -
                : ::operator new(total);
61 +
            if(!p || p == kCacheClosed)
 
62 +
                p = ::operator new(total);
58  

63  

59  
            // Trailer lets delete recover impl
64  
            // Trailer lets delete recover impl
60  
            *reinterpret_cast<strand_impl**>(
65  
            *reinterpret_cast<strand_impl**>(
61  
                static_cast<char*>(p) + padded) = &impl;
66  
                static_cast<char*>(p) + padded) = &impl;
62  
            return p;
67  
            return p;
63  
        }
68  
        }
64  

69  

65  
        void operator delete(void* p, std::size_t n) noexcept
70  
        void operator delete(void* p, std::size_t n) noexcept
66  
        {
71  
        {
67  
            constexpr auto A = alignof(strand_impl*);
72  
            constexpr auto A = alignof(strand_impl*);
68  
            std::size_t padded = (n + A - 1) & ~(A - 1);
73  
            std::size_t padded = (n + A - 1) & ~(A - 1);
69  

74  

70  
            auto* impl = *reinterpret_cast<strand_impl**>(
75  
            auto* impl = *reinterpret_cast<strand_impl**>(
71  
                static_cast<char*>(p) + padded);
76  
                static_cast<char*>(p) + padded);
72  

77  

73 -
            if (!impl->cached_frame_)
78 +
            void* expected = nullptr;
74 -
                impl->cached_frame_ = p;
79 +
            if(!impl->cached_frame_.compare_exchange_strong(
75 -
            else
80 +
                expected, p, std::memory_order_release))
76  
                ::operator delete(p);
81  
                ::operator delete(p);
77  
        }
82  
        }
78  

83  

79  
        strand_invoker get_return_object() noexcept
84  
        strand_invoker get_return_object() noexcept
80  
        { return {std::coroutine_handle<promise_type>::from_promise(*this)}; }
85  
        { return {std::coroutine_handle<promise_type>::from_promise(*this)}; }
81  

86  

82  
        std::suspend_always initial_suspend() noexcept { return {}; }
87  
        std::suspend_always initial_suspend() noexcept { return {}; }
83  
        std::suspend_never final_suspend() noexcept { return {}; }
88  
        std::suspend_never final_suspend() noexcept { return {}; }
84  
        void return_void() noexcept {}
89  
        void return_void() noexcept {}
85  
        void unhandled_exception() { std::terminate(); }
90  
        void unhandled_exception() { std::terminate(); }
86  
    };
91  
    };
87  

92  

88  
    std::coroutine_handle<promise_type> h_;
93  
    std::coroutine_handle<promise_type> h_;
89  
};
94  
};
90  

95  

91  
//----------------------------------------------------------
96  
//----------------------------------------------------------
92  

97  

93  
/** Concrete implementation of strand_service.
98  
/** Concrete implementation of strand_service.
94  

99  

95  
    Holds the fixed pool of strand_impl objects.
100  
    Holds the fixed pool of strand_impl objects.
96  
*/
101  
*/
97  
class strand_service_impl : public strand_service
102  
class strand_service_impl : public strand_service
98  
{
103  
{
99  
    static constexpr std::size_t num_impls = 211;
104  
    static constexpr std::size_t num_impls = 211;
100  

105  

101  
    strand_impl impls_[num_impls];
106  
    strand_impl impls_[num_impls];
102  
    std::size_t salt_ = 0;
107  
    std::size_t salt_ = 0;
103  
    std::mutex mutex_;
108  
    std::mutex mutex_;
104  

109  

105  
public:
110  
public:
106  
    explicit
111  
    explicit
107  
    strand_service_impl(execution_context&)
112  
    strand_service_impl(execution_context&)
108  
    {
113  
    {
109  
    }
114  
    }
110  

115  

111  
    strand_impl*
116  
    strand_impl*
112  
    get_implementation() override
117  
    get_implementation() override
113  
    {
118  
    {
114  
        std::lock_guard<std::mutex> lock(mutex_);
119  
        std::lock_guard<std::mutex> lock(mutex_);
115  
        std::size_t index = salt_++;
120  
        std::size_t index = salt_++;
116  
        index = index % num_impls;
121  
        index = index % num_impls;
117  
        return &impls_[index];
122  
        return &impls_[index];
118  
    }
123  
    }
119  

124  

120  
protected:
125  
protected:
121  
    void
126  
    void
122  
    shutdown() override
127  
    shutdown() override
123  
    {
128  
    {
124  
        for(std::size_t i = 0; i < num_impls; ++i)
129  
        for(std::size_t i = 0; i < num_impls; ++i)
125  
        {
130  
        {
126  
            std::lock_guard<std::mutex> lock(impls_[i].mutex_);
131  
            std::lock_guard<std::mutex> lock(impls_[i].mutex_);
127  
            impls_[i].locked_ = true;
132  
            impls_[i].locked_ = true;
128  

133  

129 -
            if(impls_[i].cached_frame_)
134 +
            void* p = impls_[i].cached_frame_.exchange(
130 -
            {
135 +
                kCacheClosed, std::memory_order_acquire);
131 -
                ::operator delete(impls_[i].cached_frame_);
136 +
            if(p)
132 -
                impls_[i].cached_frame_ = nullptr;
137 +
                ::operator delete(p);
133 -
            }
 
134  
        }
138  
        }
135  
    }
139  
    }
136  

140  

137  
private:
141  
private:
138  
    static bool
142  
    static bool
139  
    enqueue(strand_impl& impl, std::coroutine_handle<> h)
143  
    enqueue(strand_impl& impl, std::coroutine_handle<> h)
140  
    {
144  
    {
141  
        std::lock_guard<std::mutex> lock(impl.mutex_);
145  
        std::lock_guard<std::mutex> lock(impl.mutex_);
142  
        impl.pending_.push(h);
146  
        impl.pending_.push(h);
143  
        if(!impl.locked_)
147  
        if(!impl.locked_)
144  
        {
148  
        {
145  
            impl.locked_ = true;
149  
            impl.locked_ = true;
146  
            return true;
150  
            return true;
147  
        }
151  
        }
148  
        return false;
152  
        return false;
149  
    }
153  
    }
150  

154  

151  
    static void
155  
    static void
152  
    dispatch_pending(strand_impl& impl)
156  
    dispatch_pending(strand_impl& impl)
153  
    {
157  
    {
154  
        strand_queue::taken_batch batch;
158  
        strand_queue::taken_batch batch;
155  
        {
159  
        {
156  
            std::lock_guard<std::mutex> lock(impl.mutex_);
160  
            std::lock_guard<std::mutex> lock(impl.mutex_);
157  
            batch = impl.pending_.take_all();
161  
            batch = impl.pending_.take_all();
158  
        }
162  
        }
159  
        impl.pending_.dispatch_batch(batch);
163  
        impl.pending_.dispatch_batch(batch);
160  
    }
164  
    }
161  

165  

162  
    static bool
166  
    static bool
163  
    try_unlock(strand_impl& impl)
167  
    try_unlock(strand_impl& impl)
164  
    {
168  
    {
165  
        std::lock_guard<std::mutex> lock(impl.mutex_);
169  
        std::lock_guard<std::mutex> lock(impl.mutex_);
166  
        if(impl.pending_.empty())
170  
        if(impl.pending_.empty())
167  
        {
171  
        {
168  
            impl.locked_ = false;
172  
            impl.locked_ = false;
169  
            return true;
173  
            return true;
170  
        }
174  
        }
171  
        return false;
175  
        return false;
172  
    }
176  
    }
173  

177  

174  
    static void
178  
    static void
175  
    set_dispatch_thread(strand_impl& impl) noexcept
179  
    set_dispatch_thread(strand_impl& impl) noexcept
176  
    {
180  
    {
177  
        impl.dispatch_thread_.store(std::this_thread::get_id());
181  
        impl.dispatch_thread_.store(std::this_thread::get_id());
178  
    }
182  
    }
179  

183  

180  
    static void
184  
    static void
181  
    clear_dispatch_thread(strand_impl& impl) noexcept
185  
    clear_dispatch_thread(strand_impl& impl) noexcept
182  
    {
186  
    {
183  
        impl.dispatch_thread_.store(std::thread::id{});
187  
        impl.dispatch_thread_.store(std::thread::id{});
184  
    }
188  
    }
185  

189  

186  
    // Loops until queue empty (aggressive). Alternative: per-batch fairness
190  
    // Loops until queue empty (aggressive). Alternative: per-batch fairness
187  
    // (repost after each batch to let other work run) - explore if starvation observed.
191  
    // (repost after each batch to let other work run) - explore if starvation observed.
188  
    static strand_invoker
192  
    static strand_invoker
189  
    make_invoker(strand_impl& impl)
193  
    make_invoker(strand_impl& impl)
190  
    {
194  
    {
191  
        strand_impl* p = &impl;
195  
        strand_impl* p = &impl;
192  
        for(;;)
196  
        for(;;)
193  
        {
197  
        {
194  
            set_dispatch_thread(*p);
198  
            set_dispatch_thread(*p);
195  
            dispatch_pending(*p);
199  
            dispatch_pending(*p);
196  
            if(try_unlock(*p))
200  
            if(try_unlock(*p))
197  
            {
201  
            {
198  
                clear_dispatch_thread(*p);
202  
                clear_dispatch_thread(*p);
199  
                co_return;
203  
                co_return;
200  
            }
204  
            }
201  
        }
205  
        }
202  
    }
206  
    }
203  

207  

204  
    friend class strand_service;
208  
    friend class strand_service;
205  
};
209  
};
206  

210  

207  
//----------------------------------------------------------
211  
//----------------------------------------------------------
208  

212  

209  
// Base service is value-initialized; no state of our own to set up.
strand_service::strand_service() : service() {}
214  

218  

215  
// Defaulted out of line so the destructor is emitted in this TU.
strand_service::~strand_service() = default;
217  

221  

218  
bool
222  
bool
219  
strand_service::
223  
strand_service::
220  
running_in_this_thread(strand_impl& impl) noexcept
224  
running_in_this_thread(strand_impl& impl) noexcept
221  
{
225  
{
222  
    return impl.dispatch_thread_.load() == std::this_thread::get_id();
226  
    return impl.dispatch_thread_.load() == std::this_thread::get_id();
223  
}
227  
}
224  

228  

225  
// Run h inline (by returning it for symmetric transfer) when the
// caller is already inside this strand's invoker; otherwise queue
// it, start an invoker if the strand was idle, and return a no-op.
std::coroutine_handle<>
strand_service::dispatch(
    strand_impl& impl,
    executor_ref ex,
    std::coroutine_handle<> h)
{
    if(running_in_this_thread(impl))
        return h;

    bool const must_start = strand_service_impl::enqueue(impl, h);
    if(must_start)
    {
        auto invoker = strand_service_impl::make_invoker(impl);
        ex.post(invoker.h_);
    }
    return std::noop_coroutine();
}
236  

240  

237  
void
241  
void
238  
strand_service::
242  
strand_service::
239  
post(strand_impl& impl, executor_ref ex, std::coroutine_handle<> h)
243  
post(strand_impl& impl, executor_ref ex, std::coroutine_handle<> h)
240  
{
244  
{
241  
    if(strand_service_impl::enqueue(impl, h))
245  
    if(strand_service_impl::enqueue(impl, h))
242  
        ex.post(strand_service_impl::make_invoker(impl).h_);
246  
        ex.post(strand_service_impl::make_invoker(impl).h_);
243  
}
247  
}
244  

248  

245  
// Fetch (creating on first use) the context's strand service.
strand_service&
get_strand_service(execution_context& ctx)
{
    strand_service& svc = ctx.use_service<strand_service_impl>();
    return svc;
}
250  

254  

251  
} // namespace detail
255  
} // namespace detail
252  
} // namespace capy
256  
} // namespace capy
253  
} // namespace boost
257  
} // namespace boost