mirror of
https://github.com/micropython/micropython.git
synced 2025-07-21 04:51:12 +02:00
py/malloc: Add mutex for tracked allocations.
Fixes thread safety issue that could cause memory corruption on ports with (MICROPY_PY_THREAD && !MICROPY_PY_THREAD_GIL) - currently only rp2 and unix have this configuration. Adds unit test for TLS sockets that exercises this code path. I wasn't able to make this fail on rp2, the race condition window is pretty narrow and may not have a direct impact on a quiet system. This work was funded through GitHub Sponsors. Signed-off-by: Angus Gratton <angus@redyak.com.au>
This commit is contained in:
committed by
Damien George
parent
bee1fd5e78
commit
70ed315193
33
py/malloc.c
33
py/malloc.c
@@ -209,6 +209,31 @@ void m_free(void *ptr)
|
||||
|
||||
#if MICROPY_TRACKED_ALLOC
|
||||
|
||||
#if MICROPY_PY_THREAD && !MICROPY_PY_THREAD_GIL
|
||||
// If there's no GIL, use the GC recursive mutex to protect the tracked node linked list
|
||||
// under m_tracked_head.
|
||||
//
|
||||
// (For ports with GIL, the expectation is to only call tracked alloc functions
|
||||
// while holding the GIL.)
|
||||
|
||||
static inline void m_tracked_node_lock(void) {
|
||||
mp_thread_recursive_mutex_lock(&MP_STATE_MEM(gc_mutex), 1);
|
||||
}
|
||||
|
||||
static inline void m_tracked_node_unlock(void) {
|
||||
mp_thread_recursive_mutex_unlock(&MP_STATE_MEM(gc_mutex));
|
||||
}
|
||||
|
||||
#else
|
||||
|
||||
static inline void m_tracked_node_lock(void) {
|
||||
}
|
||||
|
||||
static inline void m_tracked_node_unlock(void) {
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
#define MICROPY_TRACKED_ALLOC_STORE_SIZE (!MICROPY_ENABLE_GC)
|
||||
|
||||
typedef struct _m_tracked_node_t {
|
||||
@@ -222,6 +247,7 @@ typedef struct _m_tracked_node_t {
|
||||
|
||||
#if MICROPY_DEBUG_VERBOSE
|
||||
static size_t m_tracked_count_links(size_t *nb) {
|
||||
m_tracked_node_lock();
|
||||
m_tracked_node_t *node = MP_STATE_VM(m_tracked_head);
|
||||
size_t n = 0;
|
||||
*nb = 0;
|
||||
@@ -234,6 +260,7 @@ static size_t m_tracked_count_links(size_t *nb) {
|
||||
#endif
|
||||
node = node->next;
|
||||
}
|
||||
m_tracked_node_unlock();
|
||||
return n;
|
||||
}
|
||||
#endif
|
||||
@@ -248,12 +275,14 @@ void *m_tracked_calloc(size_t nmemb, size_t size) {
|
||||
size_t n = m_tracked_count_links(&nb);
|
||||
DEBUG_printf("m_tracked_calloc(%u, %u) -> (%u;%u) %p\n", (int)nmemb, (int)size, (int)n, (int)nb, node);
|
||||
#endif
|
||||
m_tracked_node_lock();
|
||||
if (MP_STATE_VM(m_tracked_head) != NULL) {
|
||||
MP_STATE_VM(m_tracked_head)->prev = node;
|
||||
}
|
||||
node->prev = NULL;
|
||||
node->next = MP_STATE_VM(m_tracked_head);
|
||||
MP_STATE_VM(m_tracked_head) = node;
|
||||
m_tracked_node_unlock();
|
||||
#if MICROPY_TRACKED_ALLOC_STORE_SIZE
|
||||
node->size = nmemb * size;
|
||||
#endif
|
||||
@@ -278,7 +307,8 @@ void m_tracked_free(void *ptr_in) {
|
||||
size_t nb;
|
||||
size_t n = m_tracked_count_links(&nb);
|
||||
DEBUG_printf("m_tracked_free(%p, [%p, %p], nbytes=%u, links=%u;%u)\n", node, node->prev, node->next, (int)data_bytes, (int)n, (int)nb);
|
||||
#endif
|
||||
#endif // MICROPY_DEBUG_VERBOSE
|
||||
m_tracked_node_lock();
|
||||
if (node->next != NULL) {
|
||||
node->next->prev = node->prev;
|
||||
}
|
||||
@@ -287,6 +317,7 @@ void m_tracked_free(void *ptr_in) {
|
||||
} else {
|
||||
MP_STATE_VM(m_tracked_head) = node->next;
|
||||
}
|
||||
m_tracked_node_unlock();
|
||||
m_free(node
|
||||
#if MICROPY_MALLOC_USES_ALLOCATED_SIZE
|
||||
#if MICROPY_TRACKED_ALLOC_STORE_SIZE
|
||||
|
57
tests/extmod/ssl_threads.py
Normal file
57
tests/extmod/ssl_threads.py
Normal file
@@ -0,0 +1,57 @@
|
||||
# Ensure that SSL sockets can be allocated from multiple
|
||||
# threads without thread safety issues
|
||||
import unittest
|
||||
|
||||
try:
|
||||
import _thread
|
||||
import io
|
||||
import tls
|
||||
import time
|
||||
except ImportError:
|
||||
print("SKIP")
|
||||
raise SystemExit
|
||||
|
||||
|
||||
class TestSocket(io.IOBase):
|
||||
def write(self, buf):
|
||||
return len(buf)
|
||||
|
||||
def readinto(self, buf):
|
||||
return 0
|
||||
|
||||
def ioctl(self, cmd, arg):
|
||||
return 0
|
||||
|
||||
def setblocking(self, value):
|
||||
pass
|
||||
|
||||
|
||||
ITERS = 256
|
||||
|
||||
|
||||
class TLSThreads(unittest.TestCase):
|
||||
def test_sslsocket_threaded(self):
|
||||
self.done = False
|
||||
# only run in two threads: too much RAM demand otherwise, and rp2 only
|
||||
# supports two anyhow
|
||||
_thread.start_new_thread(self._alloc_many_sockets, (True,))
|
||||
self._alloc_many_sockets(False)
|
||||
while not self.done:
|
||||
time.sleep(0.1)
|
||||
print("done")
|
||||
|
||||
def _alloc_many_sockets(self, set_done_flag):
|
||||
print("start", _thread.get_ident())
|
||||
ctx = tls.SSLContext(tls.PROTOCOL_TLS_CLIENT)
|
||||
ctx.verify_mode = tls.CERT_NONE
|
||||
for n in range(ITERS):
|
||||
s = TestSocket()
|
||||
s = ctx.wrap_socket(s, do_handshake_on_connect=False)
|
||||
s.close() # Free associated resources now from thread, not in a GC pass
|
||||
print("done", _thread.get_ident())
|
||||
if set_done_flag:
|
||||
self.done = True
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
Reference in New Issue
Block a user