pytorch로 만든 모델을 HiddenLayer 패키지를 사용해 시각화하려고 하는데, 'torch._C.Node' object is not subscriptable라는 오류가 발생했다.
torch 버전이 업그레이드 되면서 발생하는 오류 같았는데, 원래는 module 'torch.onnx' has no attribute '_optimize_trace'라는 오류에서부터 시작한다.
솔직히 말해서 나는 개발 직군이 아니기 때문에 조금 애를 쓰면서 고쳤다. 완벽하게 고치지는 못한 것 같고, 적당히 시각화할 수 있게끔 픽스했으니 원하시는 분들은 참고만 하시기를 바란다.
내가 사용하고 있는 torch 버전은 2.2.1+cu121이다.
1. 'torch.onnx' has no attribute '_optimize_trace'
C:\Users\user\anaconda3\Lib\site-packages\hiddenlayer (자세한 위치는 사람마다 다를 수 있음) 의 pytorch_builder.py 파일을 찾는다.
torch_graph = torch.onnx._optimize_trace(trace, torch.onnx.OperatorExportTypes.ONNX)
이 부분에서 _optimize_trace를 _optimize_graph로 변경해주면 된다.
내 경우 오류가 여기서 끝나지 않았다.
2. 'torch._C.Node' object is not subscriptable
동일한 파일의 코드를 변경해 주면 된다.
변경을 필요로 하는 함수는 import_graph다. 아래와 같이 변경한다.
def import_graph(hl_graph, model, args, input_names=None, verbose=False):
# TODO: add input names to graph
# Run the Pytorch graph to get a trace and generate a graph from it
trace, out = torch.jit._get_trace_graph(model, args)
torch_graph = trace
# Dump list of nodes (DEBUG only)
if verbose:
dump_pytorch_graph(torch_graph)
# Loop through nodes and build HL graph
for torch_node in list(torch_graph.nodes()):
# Op
op = torch_node.kind()
# Parameters
params = {k: torch_node.kindOf(k) for k in torch_node.attributeNames()}
# Inputs/outputs
inputs = [i.unique() for i in torch_node.inputs()]
outputs = [o.unique() for o in torch_node.outputs()]
# Get output shape
shape = get_shape(torch_node)
# Add HL node
hl_node = Node(uid=pytorch_id(torch_node), name=None, op=op,
output_shape=shape, params=params)
hl_graph.add_node(hl_node)
# Add edges
for target_torch_node in list(torch_graph.nodes()):
target_inputs = [i.unique() for i in target_torch_node.inputs()]
if set(outputs) & set(target_inputs):
hl_graph.add_edge_by_id(pytorch_id(torch_node), pytorch_id(target_torch_node), shape)
return hl_graph
torch_graph를 trace로 변경하고, 아래의 for문에 .nodes()를 추가해 주기 때문에 사실 1번과 같은 수정은 필요하지 않았다.
원래는 torch._C.jit_pass_inline(trace.graph)를 trace, out = torch.jit._get_trace_graph(model, args) 아래에 추가하는 방법도 있었다. 하지만 torch 버전이 업그레이드 되면서 코드가 'torch._C.Graph' object has no attribute 'graph' < 이런 오류를 뱉어낸다.
위와 같이 변경하면 적어도 시각화가 돌아가기는 하므로 실행시키고 싶으신 분들은 참고하세요...
참고로 예시와 같은 예쁜 블록이 나오지는 않는데, 왜 그런가 하고 코드를 뜯어보니 node의 default name이 None으로 되어 있고, None일 경우 node name으로 op (torch_node.kind())를 사용하고 있어서 아래와 같은 못생긴 블록이 출력된다. 아니면 나만 이런가? 잘 모르겠음... torch를 다운그레이드 해 봐야 싶기도 하고... 심지어 constant까지 보여주기 때문에 솔직히 HiddenLayer보다 torchviz가 적합한 것 같다.

못생긴 시각화
'공부 > Python' 카테고리의 다른 글
FATAL: no PostgreSQL user name specified in startup packet (0) | 2024.12.21 |
---|---|
Failed to import pytorch fbgemm.dll or one of its dependencies is missing. (0) | 2024.11.23 |