Async patterns in Python: DAG traversal
One of the first pieces of advice you receive when starting to write async Python code is not to block the thread running the event loop, as otherwise no other task can make progress, essentially defeating the purpose of writing async code in the first place. So, I quickly learned to find and replace all blocking calls with asyncio variants, e.g. `time.sleep(1.0)` became `await asyncio.sleep(1.0)`, and packages like `requests` were replaced by alternatives that offer async APIs, like `aiohttp` or `httpx`. In several instances, this quickly became a huge task, as the process turns recursive: Python coroutines have to be awaited from another coroutine, and that coroutine in turn has to be awaited from yet another, all the way to the top of the call stack.
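When a library offers no async API at all, one escape hatch (my suggestion here, not something the post prescribes) is `asyncio.to_thread`, which runs the blocking call in a worker thread so the event loop thread stays free:

```python
import asyncio
import time


def blocking_fetch() -> str:
    # Stand-in for a blocking call, e.g. requests.get(...).
    time.sleep(0.2)
    return "payload"


async def main() -> str:
    # to_thread moves the blocking call off the event loop thread,
    # so other tasks keep running while it sleeps.
    return await asyncio.to_thread(blocking_fetch)


if __name__ == "__main__":
    print(asyncio.run(main()))
```

This avoids the recursive rewrite entirely for leaf calls, at the cost of a thread per in-flight call.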
But of course, it was all worth it! As soon as I finished `await`-ing all those coroutines I would start rejoicing in the performance benefits of concurrent work: “I pay for 100% of the CPU, I am going to use 100% of the CPU! No idle time!”.
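A minimal illustration of that payoff (my toy example, not from the DAG code below): three awaited sleeps overlap under `asyncio.gather`, so the total wall time is roughly one second rather than three:

```python
import asyncio
import time


async def main() -> float:
    start = time.perf_counter()
    # Three 1-second waits run concurrently instead of back to back.
    await asyncio.gather(asyncio.sleep(1), asyncio.sleep(1), asyncio.sleep(1))
    return time.perf_counter() - start


if __name__ == "__main__":
    print(f"elapsed: {asyncio.run(main()):.2f}s")  # roughly 1s, not 3s
```

The caveat, of course, is that the waiting overlaps; actual CPU-bound work on the event loop thread still runs one task at a time.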
```python
import asyncio
import dataclasses
import enum

NodeLabel = str
NodeIndex = int


class NodeStatus(enum.Enum):
    READY = 0
    COMPLETED = 1
    FAILED = 2
    FAILED_PARENT = 3


@dataclasses.dataclass
class Node:
    label: NodeLabel


@dataclasses.dataclass
class QueueMessage:
    node_index: NodeIndex
    status: NodeStatus


Edge = tuple[NodeIndex, NodeIndex]


@dataclasses.dataclass
class DAG:
    nodes: list[Node]
    edges: list[Edge]


async def traverse_dag(dag: DAG):
    queue: asyncio.Queue[QueueMessage] = asyncio.Queue()

    # A node is a root if no edge points at it; also count each node's
    # parents so we know when it becomes ready to run.
    root_nodes = set(range(len(dag.nodes)))
    direct_children: dict[NodeIndex, set[NodeIndex]] = {}
    pending_parents = {index: 0 for index in range(len(dag.nodes))}

    for parent, child in dag.edges:
        root_nodes.discard(child)
        direct_children.setdefault(parent, set()).add(child)
        pending_parents[child] += 1

    for root_node_index in root_nodes:
        queue.put_nowait(QueueMessage(node_index=root_node_index, status=NodeStatus.READY))

    background_tasks = set()
    processed_nodes: set[NodeIndex] = set()

    while len(processed_nodes) < len(dag.nodes):
        match await queue.get():
            case QueueMessage(node_index=node_index, status=NodeStatus.READY):
                node = dag.nodes[node_index]
                work_task = asyncio.create_task(do_work_on_node(node=node))
                background_tasks.add(work_task)

                # Done callbacks run synchronously, so report back with
                # put_nowait rather than awaiting queue.put. Binding
                # node_index as a default argument freezes its current
                # value; a plain closure would see a later iteration's.
                def work_task_callback(task, node_index=node_index):
                    background_tasks.discard(task)
                    if task.exception() is None:
                        queue.put_nowait(QueueMessage(node_index=node_index, status=NodeStatus.COMPLETED))
                    else:
                        queue.put_nowait(QueueMessage(node_index=node_index, status=NodeStatus.FAILED))

                work_task.add_done_callback(work_task_callback)

            case QueueMessage(node_index=node_index, status=NodeStatus.COMPLETED):
                processed_nodes.add(node_index)

                # A child becomes READY once its last parent completes.
                for child_index in direct_children.get(node_index, ()):
                    pending_parents[child_index] -= 1
                    if pending_parents[child_index] == 0:
                        await queue.put(QueueMessage(node_index=child_index, status=NodeStatus.READY))

            case QueueMessage(node_index=node_index, status=NodeStatus.FAILED | NodeStatus.FAILED_PARENT):
                processed_nodes.add(node_index)

                # Propagate the failure: descendants of a failed node
                # never run, but still count as processed.
                for child_index in direct_children.get(node_index, ()):
                    await queue.put(QueueMessage(node_index=child_index, status=NodeStatus.FAILED_PARENT))


async def do_work_on_node(node: Node):
    await asyncio.sleep(1)
```
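The trickiest seam in the traversal is handing a finished task's outcome back to the coroutine draining the queue: `add_done_callback` invokes a plain synchronous function, which cannot `await queue.put(...)`. Here is that handoff pattern isolated from the DAG logic, as a sketch with names of my own choosing:

```python
import asyncio


async def worker(n: int) -> int:
    await asyncio.sleep(0.01)
    if n == 2:
        raise ValueError("boom")
    return n * n


async def main() -> list[str]:
    queue: asyncio.Queue[str] = asyncio.Queue()

    def on_done(task: asyncio.Task, n: int) -> None:
        # Done callbacks are synchronous, so use the non-blocking
        # put_nowait instead of awaiting queue.put.
        status = "failed" if task.exception() is not None else "completed"
        queue.put_nowait(f"{n}:{status}")

    for n in range(3):
        task = asyncio.create_task(worker(n))
        # Bind n as a default argument so each callback reports
        # its own worker, not the loop's final value.
        task.add_done_callback(lambda t, n=n: on_done(t, n))

    # Drain one message per launched task.
    return sorted([await queue.get() for _ in range(3)])


if __name__ == "__main__":
    print(asyncio.run(main()))
```

Calling `task.exception()` in the callback also marks the exception as retrieved, which silences asyncio's "exception was never retrieved" warning for the failing worker.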