#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# author: bt3gl


class Node:
    """Node of a multilevel doubly linked list.

    Attributes:
        val: Payload stored in the node.
        prev: Previous node on the same level, or None.
        next: Following node on the same level, or None.
        child: Head of a nested list hanging off this node, or None.
    """

    def __init__(self, val, prev=None, next=None, child=None):
        self.val = val
        self.prev = prev
        self.next = next
        self.child = child


def dfs(prev, node):
    """Splice the multilevel list starting at ``node`` right after ``prev``.

    Nodes are visited depth-first: a node's child list is linked in before
    the node's original ``next``. Every visited node gets its ``child``
    pointer cleared.

    Args:
        prev: Node after which the flattened nodes are attached. Must not
            be None.
        node: Head of the (sub)list to flatten, or None.

    Returns:
        The last node that was linked, or ``prev`` if ``node`` is None.
    """
    # An explicit stack replaces recursion so that long lists do not hit
    # the interpreter's recursion limit.
    pending = [node]
    while pending:
        current = pending.pop()
        if current is None:
            continue

        current.prev = prev
        prev.next = current

        # Push the original successor first so the child list is popped
        # (and therefore linked) before it. ``current.next`` must be read
        # here, before the next iteration overwrites it.
        pending.append(current.next)
        pending.append(current.child)
        current.child = None

        prev = current
    return prev


def flatten(head):
    """Flatten a multilevel doubly linked list in place.

    Args:
        head: First node of the list, or None for an empty list.

    Returns:
        The head of the flattened list (the same object as ``head``), or
        None when ``head`` is None.
    """
    if head is None:
        return head

    # A temporary sentinel gives the first real node a non-None predecessor.
    sentinel = Node(None, None, head, None)
    dfs(prev=sentinel, node=head)

    # erase the pointer to sentinel and return
    sentinel.next.prev = None
    return sentinel.next