COROUTINES EVERYWHERE
Alexey Kutumov
Senior software engineer at Kaspersky Lab
AGENDA
resumable_function
co_await
C++ runtime
Fine tuning coroutine
RESUMABLE_FUNCTION
RESUMABLE FUNCTION
// Minimal resumable function (coroutine) used to introduce the concept:
// DoThis() runs first, the function then awaits whatever StartSomeAsyncOp()
// returns, and DoThat(someArg) runs after the co_await expression completes.
// NOTE(review): standard C++20 does not allow a coroutine to have a deduced
// (`auto`) return type; presumably slide shorthand - confirm before reuse.
auto ResumableFunction(std::string someArg) {
DoThis();
co_await StartSomeAsyncOp();
// someArg is a by-value parameter, so it is still available here.
DoThat(someArg);
}
WINAPI SAMPLE (OLD SCHOOL)
struct AsyncContext : public OVERLAPPED {
HANDLE m_event;
AsyncContext()
: OVERLAPPED()
, m_event(::CreateEvent(nullptr, true, true, nullptr)) {
}
static void WINAPI OnWrite(DWORD error, DWORD bytesWritten, LPOVERLAPPED overlapped) {
AsyncContext* me = static_cast<AsyncContext*>(overlapped);
std::cout << "done, bytes written: " << bytesWritten << std::endl;
::SetEvent(me->m_event);
}
};
WINAPI SAMPLE (OLD SCHOOL)
// C-style sample: starts one overlapped write and waits for its completion.
// NOTE(review): `...` stands for CreateFile arguments elided on the slide, and
// the file handle is never closed in this snippet.
void OldSchoolWinapiWaiter() {
const char data [] = "hello from good ol` C!";
auto file = ::CreateFile("1.txt", ..., FILE_FLAG_OVERLAPPED);
AsyncContext ctx;
std::cout << "starting..." << std::endl;
// sizeof(data) - 1 excludes the terminating '\0'; ctx doubles as the OVERLAPPED.
::WriteFileEx(file, data, sizeof(data) - 1, &ctx, AsyncContext::OnWrite);
// Alertable wait (third argument true) on the event that OnWrite signals.
::WaitForSingleObjectEx(ctx.m_event, INFINITE, true);
}
WINAPI SAMPLE (C++)
/// Base class that privately embeds an OVERLAPPED and converts between the
/// OVERLAPPED pointer and the enclosing I/O object.
struct AsyncIoBase : private OVERLAPPED {
    /// Value-initializes the embedded OVERLAPPED.
    AsyncIoBase()
        : OVERLAPPED() {}

    /// @return pointer to the embedded OVERLAPPED sub-object.
    OVERLAPPED* GetOverlapped() {
        OVERLAPPED* const embedded = this;
        return embedded;
    }

    /// Converts an OVERLAPPED pointer back into the I/O object that embeds it.
    /// @tparam AsyncIoType  type to cast to; the cast is unchecked.
    /// @param overlapped    pointer previously obtained from GetOverlapped().
    template <typename AsyncIoType>
    static AsyncIoType* GetAsyncIo(OVERLAPPED* overlapped) {
        AsyncIoType* const io = static_cast<AsyncIoType*>(overlapped);
        return io;
    }
};
WINAPI SAMPLE (C++)
struct AsyncWriterBase : public AsyncIoBase {
explicit AsyncWriterBase(HANDLE handle)
: AsyncIoBase(), m_handle(handle) {
}
void StartAsyncWrite(const void* data, size_t size) {
::WriteFileEx(m_handle, data, size, GetOverlapped(), OnWrite);
}
virtual void WriteFinished(DWORD bytesWritten) = 0;
static void WINAPI OnWrite(DWORD error, DWORD bytesWritten, LPOVERLAPPED overlapped) {
AsyncWriterBase* me = GetAsyncIo<AsyncWriterBase>(overlapped);
me->WriteFinished(bytesWritten);
}
HANDLE m_handle;
};
WINAPI SAMPLE (C++)
struct WinapiAsyncWriter : public AsyncWriterBase {
explicit WinapiAsyncWriter(HANDLE handle);
HANDLE WriteAsync(const std::string& message) {
m_message = message;
std::cout << "starting ..." << std::endl;
StartAsyncWrite(m_message.data(), m_message.size());
return m_promise;
}
virtual void WriteFinished(DWORD bytesWritten) {
std::cout << "done, bytes written: " << bytesWritten << std::endl;
::SetEvent(m_promise);
}
HANDLE m_promise;
std::string m_message;
};
WINAPI SAMPLE (C++)
// Drives WinapiAsyncWriter: starts one write and waits on the returned handle.
// NOTE(review): `...` stands for CreateFile arguments elided on the slide, and
// the file handle is never closed in this snippet.
void TryWinApiWaiter() {
auto fileHandle = ::CreateFile("2.txt", ..., FILE_FLAG_OVERLAPPED);
WinapiAsyncWriter asyncWriter{fileHandle};
// `future` is the HANDLE returned by WriteAsync (the writer's m_promise).
auto future = asyncWriter.WriteAsync("hello world");
// Alertable wait (third argument true) until WriteFinished signals the handle.
::WaitForSingleObjectEx(future, INFINITE, true);
}
WINAPI SAMPLE (CORO)
struct WinApiAwaitable : AsyncWriterBase {
WinApiAwaitable(HANDLE handle, const char* data, size_t size);
bool await_ready();
void await_suspend(std::coroutine_handle<coro::promise<void>> coroutineHandle);
DWORD await_resume();
private:
std::coroutine_handle<coro::promise<void>> m_coroutineHandle;
const void* m_data;
size_t m_size;
DWORD m_bytesWritten;
virtual void WriteFinished(DWORD bytesWritten);
};
WINAPI SAMPLE (CORO)
// Coroutine version of the asynchronous write: prints a message, awaits a
// WinApiAwaitable built over `message`, then prints the byte count produced
// by the co_await expression.
coro::future<void> WriteAsync(HANDLE handle, std::string message) {
std::cout << "starting ..." << std::endl;
// The awaitable only receives a pointer/size into `message`; `message` is a
// by-value parameter of this coroutine.
DWORD bytesWritten = co_await WinApiAwaitable{handle, message.data(), message.size()};
std::cout << "done, bytes written: " << bytesWritten << std::endl;
}
// Drives the coroutine sample: opens the file, starts WriteAsync and calls
// get_value() on the returned future.
// NOTE(review): `...` stands for CreateFile arguments elided on the slide, and
// the file handle is never closed in this snippet.
void TrySimpleCoroWaiter() {
auto file = ::CreateFile(TEXT("3.txt"), ..., FILE_FLAG_OVERLAPPED);
auto future = WriteAsync(file, "hello coroutines!");
// presumably blocks until the coroutine has finished - confirm against coro::future.
future.get_value();
}
WINAPI SAMPLE (CORO)
// Remembers the awaiting coroutine and starts the overlapped write.
void WinApiAwaitable::await_suspend(std::coroutine_handle<promise_type> handle) {
// Store the handle first: WriteFinished() uses it to resume the coroutine.
m_coroutineHandle = handle;
StartAsyncWrite(m_data, m_size);
}
// Completion callback: records the result and resumes the awaiting coroutine.
void WinApiAwaitable::WriteFinished(DWORD bytesWritten) {
// Must be stored before resume(): await_resume() returns this value.
m_bytesWritten = bytesWritten;
m_coroutineHandle.resume();
}
// Result of the co_await expression: the byte count stored by WriteFinished().
DWORD WinApiAwaitable::await_resume() {
return m_bytesWritten;
}
CO_AWAIT
CO_AWAIT
// The coroutine whose compiler transformation is shown on the following
// slides: print, await the overlapped write, print the byte count.
coro::future<void> WriteAsync(HANDLE handle, std::string message) {
std::cout << "starting ..." << std::endl;
// The awaitable only receives a pointer/size into `message`; `message` is a
// by-value parameter of this coroutine.
auto bytesWritten = co_await WinApiAwaitable{handle, message.data(), message.size()};
std::cout << "done, bytes written: " << bytesWritten << std::endl;
}
CO_AWAIT
/// Signature of the function that runs/resumes a coroutine; it receives the
/// coroutine frame as an untyped pointer.
using CoroResumeFunction = void (*)(void*);

/// Position of the hand-written coroutine state machine: Init before the
/// body has run, Resume once it has reached its suspension point.
enum class CoroState { Init, Resume };

/// Common prefix of a coroutine frame: the resume function, the promise
/// object and the current state.
/// @tparam PromiseType  promise type stored in the frame; value-initialized.
template <typename PromiseType>
struct coro_frame_header {
    CoroResumeFunction resumeAddr;
    PromiseType promise;
    CoroState state;  // the original declaration was missing its ';'

    /// @param fn  function invoked to run or resume the coroutine.
    explicit coro_frame_header(CoroResumeFunction fn)
        : resumeAddr(fn)
        , promise()
        , state(CoroState::Init) {
    }
};
CO_AWAIT
/// Hand-written coroutine frame for WriteAsync: the frame header followed by
/// the function's parameters and the awaitable that lives across the
/// suspension point.
struct FunctionContext : coro_frame_header<coro::promise<void>> {
    HANDLE handle;
    std::string message;
    WinApiAwaitable awaitable;

    /// @param h   file handle passed to WriteAsync.
    /// @param m   message to write; moved into the frame.
    /// @param fn  resume function stored in the frame header.
    FunctionContext(HANDLE h, std::string m, CoroResumeFunction fn)
        // The base must be named by its real type; there is no `frame_header`.
        : coro_frame_header<coro::promise<void>>(fn)
        , handle(h)
        , message(std::move(m))
        // Members are initialized in declaration order, so `handle` and
        // `message` are already constructed when the awaitable captures them.
        , awaitable(handle, message.data(), message.size()) {
    }
};
CO_AWAIT
template <typename PromiseType>
struct coroutine_handle {
coro_frame_header<PromiseType>* m_frame
coroutine_handle(coro_frame_header<PromiseType>* frame)
: m_frame(frame) {
}
void resume() {
m_frame->resumeAddr(m_frame);
}
void destroy();
};
CO_AWAIT
// Remembers the awaiting coroutine and starts the overlapped write.
void WinApiAwaitable::await_suspend(std::coroutine_handle<promise_type> handle) {
// Store the handle first: WriteFinished() uses it to resume the coroutine.
m_coroutineHandle = handle;
StartAsyncWrite(m_data, m_size);
}
// Result of the co_await expression: the byte count stored by WriteFinished().
DWORD WinApiAwaitable::await_resume() {
return m_bytesWritten;
}
// Completion callback: records the result and resumes the awaiting coroutine.
void WinApiAwaitable::WriteFinished(DWORD bytesWritten) {
// Must be stored before resume(): await_resume() returns this value.
m_bytesWritten = bytesWritten;
m_coroutineHandle.resume();
}
CO_AWAIT
// Source form of the coroutine; the next slides show the hand-written
// equivalent (CoroFunction plus a lowered WriteAsync).
coro::future<void> WriteAsync(HANDLE handle, std::string message) {
std::cout << "starting ..." << std::endl;
// The awaitable only receives a pointer/size into `message`; `message` is a
// by-value parameter of this coroutine.
auto bytesWritten = co_await WinApiAwaitable{handle, message.data(), message.size()};
std::cout << "done, bytes written: " << bytesWritten << std::endl;
}
CO_AWAIT
void CoroFunction(void* rawCtx) {
FunctionContext* ctx = static_cast<FunctionContext*>(rawCtx);
if (ctx->state == CoroState::Init) {
ctx->state = CoroState::Resume;
std::cout << "starting ..." << std::endl;
ctx->awaitable.await_suspend(ctx);
} else {
DWORD bytesWritten = ctx->awaitable.await_resume();
std::cout << "done, bytes written: " << bytesWritten << std::endl;
delete ctx;
}
}
CO_AWAIT
/// Hand-written equivalent of the WriteAsync coroutine's entry point:
/// allocates the frame, obtains the future and runs the body up to its first
/// suspension point.
/// @param handle   file handle to write to.
/// @param message  text to write; moved into the frame.
/// @return future associated with the frame's promise.
coro::future<void> WriteAsync(HANDLE handle, std::string message) {
    FunctionContext* ctx = new FunctionContext{handle, std::move(message), CoroFunction};
    // Take the future BEFORE running the body: CoroFunction deletes the frame
    // when the coroutine finishes, so ctx must not be touched afterwards.
    auto future = ctx->promise.get_future();
    CoroFunction(ctx);
    return future;
}
C++ RUNTIME
C++ RUNTIME
// First allocation/deallocation pair listed on the slide: the ordinary
// allocation function and the deallocation function taking a size argument.
void* operator new(size_t size);
void operator delete(void* location, size_t);
ИЛИ
// Alternative pair: the std::nothrow_t overload of operator new together with
// the same deallocation function.
// NOTE(review): the standard <new> header declares this overload noexcept -
// confirm whether that is intended here.
void* operator new(size_t size, std::nothrow_t const &);
void operator delete(void* location, size_t);
C++ RUNTIME
/// Hand-written equivalent of the WriteAsync coroutine's entry point, using
/// the ordinary (throwing) operator new for the frame.
/// @param handle   file handle to write to.
/// @param message  text to write; moved into the frame.
/// @return future associated with the frame's promise.
coro::future<void> WriteAsync(HANDLE handle, std::string message) {
    FunctionContext* ctx = new FunctionContext{handle, std::move(message), CoroFunction};
    // Take the future BEFORE running the body: CoroFunction deletes the frame
    // when the coroutine finishes, so ctx must not be touched afterwards.
    auto future = ctx->promise.get_future();
    CoroFunction(ctx);
    return future;
}
C++ RUNTIME
/// Hand-written equivalent of the WriteAsync coroutine's entry point for an
/// environment without exceptions: the frame is allocated with the nothrow
/// operator new and an allocation failure is reported through the return
/// object.
/// @param handle   file handle to write to.
/// @param message  text to write; moved into the frame.
/// @return future associated with the frame's promise, or the
///         allocation-failure return object when the frame cannot be allocated.
coro::future<void> WriteAsync(HANDLE handle, std::string message) {
    typedef std::coroutine_traits<coro::future<void>, HANDLE, std::string> coro_traits;

    auto* ctx = new (std::nothrow) FunctionContext{handle, std::move(message), CoroFunction};
    if (!ctx) {
        return coro_traits::get_return_object_on_allocation_failure();
    }

    // Take the future BEFORE running the body: CoroFunction deletes the frame
    // when the coroutine finishes, so ctx must not be touched afterwards.
    auto future = ctx->promise.get_future();
    CoroFunction(ctx);
    return future;
}
C++ RUNTIME
/// Pool tag attached to every allocation made through these operators.
constexpr ULONG CrtPoolTag = 'CRTP';

/// Non-throwing allocation function backed by the kernel pool allocator.
/// @param size  number of bytes requested.
/// @return whatever ExAllocatePoolWithTag returns for a PagedPool request.
/// NOTE(review): the standard nothrow overload is declared noexcept; left
/// unchanged here so the definition keeps matching its existing declaration.
void* operator new(size_t size, struct std::nothrow_t const &) {
    void* const block = ::ExAllocatePoolWithTag(PagedPool, size, CrtPoolTag);
    return block;
}

/// Deallocation function matching the allocator above; a null pointer is ignored.
/// @param location  block previously returned by operator new, or null.
void operator delete(void* location, size_t) {
    if (!location) {
        return;
    }
    ::ExFreePoolWithTag(location, CrtPoolTag);
}
FINE TUNING COROUTINE
THEORY (WINDOWS KERNEL + EFI)
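Code runs in response to an interrupt or a task (IRQ – km, Task – efi).
Every piece of code runs at a priority level (Level) (IRQL – km, TPL – efi).
Code at a higher level preempts code at a lower level.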
LEVELS OF EXECUTION
Windows kernel
PASSIVE_LEVEL
APC_LEVEL
DISPATCH_LEVEL
…
DIRQL
…
HIGH_LEVEL
EFI
TPL_APPLICATION
TPL_CALLBACK
TPL_NOTIFY
FW interrupts
TPL_HIGH_LEVEL
INITIAL + FINAL SUSPEND POINTS
// Pseudo-code: the coroutine body with the implicit initial and final suspend
// points written out. `???` is a placeholder left on the slide for the object
// that provides initial_suspend()/final_suspend(); `__co_await` marks the
// implicit awaits. Not compilable as written.
coro::future<void> WriteAsync(HANDLE handle, std::string message) {
// Implicit await before the user-written body.
auto __is = ???.initial_suspend();
__co_await __is;
std::cout << "starting ..." << std::endl;
auto bytesWritten = co_await WinApiAwaitable{handle, message.data(), message.size()};
std::cout << "done, bytes written: " << bytesWritten << std::endl;
// Implicit await after the user-written body.
auto __fs = ???.final_suspend();
__co_await __fs;
}
INITIAL SUSPEND
/// Hand-written equivalent of the WriteAsync coroutine's entry point with the
/// initial suspend point made explicit.
/// @param handle   file handle to write to.
/// @param message  text to write; moved into the frame.
/// @return future associated with the frame's promise, or the
///         allocation-failure return object when the frame cannot be allocated.
coro::future<void> WriteAsync(HANDLE handle, std::string message) {
    typedef std::coroutine_traits<coro::future<void>, HANDLE, std::string> coro_traits;

    auto* ctx = new (std::nothrow) FunctionContext{handle, std::move(message), CoroFunction};
    if (!ctx) {
        return coro_traits::get_return_object_on_allocation_failure();
    }

    // Obtain the return object first: both branches below may end up freeing
    // the frame (the initial-suspend awaiter can destroy the coroutine in
    // await_suspend, and CoroFunction deletes the frame when the body
    // finishes), so ctx must not be dereferenced after them.
    auto future = ctx->promise.get_future();

    auto is = ctx->promise.initial_suspend();
    if (is.await_ready()) {
        is.await_resume();
        CoroFunction(ctx);
    } else {
        is.await_suspend(ctx);
    }
    return future;
}
CONSTRUCTORS IN NOEXCEPT ENVIRONMENT
// Pseudo-code trace, not compilable code: the first line is the allocation
// expression, the following lines list calls made while it is evaluated.
// NOTE(review): the nesting of the calls was presumably shown by indentation
// that is lost here - confirm against the original slide.
auto* ctx = new (std::nothrow) FunctionContext{args};
::operator new(sizeof(FunctionContext), std::nothrow);
FunctionContext::FunctionContext(args);
coro::promise<void>::coro::promise<void>();
coro::promise<void>::construct_shared_state();
make_shared_nothrow<State>();
// NOTE(review): the third argument (true) creates the event already signaled - confirm this is intended.
::CreateEvent(nullptr, true, true, nullptr)
INITIAL SUSPEND
/// Awaiter returned by promise::initial_suspend(). When constructed with
/// `valid == true` it reports ready and the coroutine proceeds; otherwise
/// await_suspend() destroys the coroutine.
struct check_for_broken_promise {
    const bool m_valid;

    /// @param valid  whether the coroutine is allowed to proceed.
    explicit check_for_broken_promise(bool valid)
        : m_valid(valid) {
    }

    /// Ready (no suspension) exactly when the awaiter was built as valid.
    bool await_ready() noexcept {
        return m_valid;
    }

    /// Reached only when not ready: tears the coroutine down.
    void await_suspend(std::experimental::coroutine_handle<promise> coroutine) noexcept {
        coroutine.destroy();
    }

    /// Nothing to produce.
    void await_resume() noexcept {
    }
};
/// Builds the initial-suspend awaiter from the validity of the promise's
/// state, as reported by detail::IsValid(m_state).
template <typename Type>
check_for_broken_promise promise<Type>::initial_suspend() {
    const bool stateIsValid = detail::IsValid(m_state);
    return check_for_broken_promise{stateIsValid};
}
FINAL SUSPEND
/// Hand-written coroutine body with the final suspend point made explicit.
/// @param ctx  frame of the coroutine being run or resumed.
///
/// On resumption the body finishes, then the promise's final-suspend awaiter
/// decides whether the frame is deleted here (awaiter ready) or handed to the
/// awaiter's await_suspend() (awaiter not ready).
void CoroFunction(FunctionContext* ctx) {
    if (ctx->state == CoroState::Init) {
        /// starting...
        return;
    }

    DWORD bytesWritten = ctx->awaitable.await_resume();
    std::cout << "done, bytes written: " << bytesWritten << std::endl;

    auto fs = ctx->promise.final_suspend();
    if (!fs.await_ready()) {
        // Not ready: the awaiter takes over; the frame is not deleted here.
        fs.await_suspend(ctx);
        return;
    }

    fs.await_resume();
    delete ctx; // destroy coro
}
SHARED_STATE_BASE
struct shared_state_base {
~shared_state_base() {
DestroyCoroutine();
}
void DestroyCoroutine() {
if (m_coroutine) {m_coroutine.destroy(); m_coroutine = nullptr;}
}
coro::error_code Notify(std::experimental::coroutine_handle<> coroutine = nullptr) {
m_coroutine = coroutine;
return m_event.Notify();
}
PlatformEventType m_event;
std::experimental::coroutine_handle<> m_coroutine;
};
SHARED_STATE_BASE
/// If the future still refers to a state, asks it to destroy its coroutine.
~future() {
    if (!m_state) {
        return;
    }
    m_state->DestroyCoroutine();
}
/// If the promise still refers to a state, asks it to destroy its coroutine.
~promise() {
    if (!m_state) {
        return;
    }
    m_state->DestroyCoroutine();
}
FINAL_SUSPEND
/// Final-suspend awaiter that always suspends and passes the coroutine's own
/// handle to its promise's notify().
struct notify_and_destroy_coroutine_in_caller_context {
    /// Never ready: await_suspend() is always invoked.
    bool await_ready() noexcept {
        return false;
    }

    /// Hands the suspended coroutine's handle to its promise.
    void await_suspend(std::experimental::coroutine_handle<promise> coroutine) noexcept {
        auto& owner = coroutine.promise();
        owner.notify(coroutine);
    }

    /// Nothing to produce.
    void await_resume() noexcept {
    }
};
FINAL_SUSPEND
/// Final-suspend awaiter that never suspends: it reports ready and calls
/// Notify() on the shared state from await_resume().
struct notify_and_destroy_coroutine_in_async_context {
    /// @param state  shared state to notify; a copy is kept in m_state.
    notify_and_destroy_coroutine_in_async_context(coro::shared_state_ptr state)
        : m_state(state) {
    }

    /// Always ready, so await_suspend() is not used.
    bool await_ready() noexcept {
        return true;
    }

    /// Intentionally empty.
    void await_suspend(std::experimental::coroutine_handle<> /* coroutine */) noexcept {
    }

    /// Notifies the shared state without passing a coroutine handle.
    void await_resume() noexcept {
        m_state->Notify();
    }

    coro::shared_state_ptr m_state;
};
LET'S TALK? [email protected]
https://github.com/prograholic/coroutines_km/
Top Related