import%20marimo%0A%0A__generated_with%20%3D%20%220.10.9%22%0Aapp%20%3D%20marimo.App(width%3D%22medium%22)%0A%0A%0A%40app.cell%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(r%22%22%22%23%20Displaying%20CNN%20Activations%20in%20Marimo%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20import%20marimo%20as%20mo%0A%20%20%20%20import%20json%0A%20%20%20%20import%20torch%0A%20%20%20%20import%20pandas%20as%20pd%0A%20%20%20%20import%20matplotlib.pyplot%20as%20plt%0A%20%20%20%20from%20torchvision.io%20import%20decode_image%0A%20%20%20%20from%20torchvision.models%20import%20alexnet%2C%20AlexNet_Weights%0A%20%20%20%20from%20torchvision.transforms%20import%20v2%0A%20%20%20%20import%20plotly.graph_objects%20as%20go%0A%20%20%20%20from%20plotly.subplots%20import%20make_subplots%0A%20%20%20%20return%20(%0A%20%20%20%20%20%20%20%20AlexNet_Weights%2C%0A%20%20%20%20%20%20%20%20alexnet%2C%0A%20%20%20%20%20%20%20%20decode_image%2C%0A%20%20%20%20%20%20%20%20go%2C%0A%20%20%20%20%20%20%20%20json%2C%0A%20%20%20%20%20%20%20%20make_subplots%2C%0A%20%20%20%20%20%20%20%20mo%2C%0A%20%20%20%20%20%20%20%20pd%2C%0A%20%20%20%20%20%20%20%20plt%2C%0A%20%20%20%20%20%20%20%20torch%2C%0A%20%20%20%20%20%20%20%20v2%2C%0A%20%20%20%20)%0A%0A%0A%40app.cell%0Adef%20_(mo)%3A%0A%20%20%20%20image_selection%20%3D%20mo.ui.dropdown(%5B%22french_loaf.jpg%22%2C%20%22maltese.jpeg%22%2C%20%22snail.jpeg%22%5D%2C%20value%3D%22snail.jpeg%22)%0A%20%20%20%20image_selection%0A%20%20%20%20return%20(image_selection%2C)%0A%0A%0A%40app.cell%0Adef%20_(image_selection%2C%20mo)%3A%0A%20%20%20%20image_path%20%3D%20f%22data%2F%7Bimage_selection.value%7D%22%0A%0A%20%20%20%20mo.image(src%3Dimage_path)%0A%20%20%20%20return%20(image_path%2C)%0A%0A%0A%40app.cell%0Adef%20_(activations%2C%20go%2C%20image_selection%2C%20labels%2C%20make_subplots)%3A%0A%20%20%20%20preds%20%3D%20activations%5B%22classification%22%5D.flatten()%0A%0A%20%20%20%20top_preds%20%3D%20preds.argsort()%5B-5%3A%5D.tolist()%0A%20%20%20%20fig%20%3D%20make_subplots()%0A%20%20%20%20fig.add_trace(go.Bar(x%3Dlist(range(len(preds)))%2C%20y%3Dpreds))%0A%20%20%20%20fig.add_trace(go.Scatter(x%3Dtop_preds%2C%20y%3Dpreds%5Btop_preds%5D%2C%20text%3D%5Blabels%5Bstr(i)%5D%20for%20i%20in%20top_preds%5D%2C%20mode%3D%22markers%2Btext%22%2C%20textposition%3D%22top%20center%22))%0A%20%20%20%20fig.update_layout(title_text%3Df%22Top-5%20Class%20predictions%20for%20%7Bimage_selection.value%7D%22)%0A%20%20%20%20return%20fig%2C%20preds%2C%20top_preds%0A%0A%0A%40app.cell%0Adef%20_(AlexNet_Weights%2C%20alexnet)%3A%0A%20%20%20%20%23%20Initialize%20the%20AlexNet%20model%20with%20weights%2C%20trained%20on%20ImageNet%0A%20%20%20%20model%20%3D%20alexnet(weights%3DAlexNet_Weights.DEFAULT)%0A%0A%20%20%20%20%23%20Set%20the%20model%20into%20inference%20mode%0A%20%20%20%20model%20%3D%20model.eval()%0A%20%20%20%20return%20(model%2C)%0A%0A%0A%40app.cell%0Adef%20_(model)%3A%0A%20%20%20%20model%0A%20%20%20%20return%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(%0A%20%20%20%20%20%20%20%20r%22%22%22%0A%20%20%20%20%20%20%20%20%23%20Adding%20forward%20hooks%20to%20AlexNet%0A%0A%20%20%20%20%20%20%20%20To%20extract%20the%20activations%20of%20a%20particular%20layer%2C%20we%20need%20to%20intercept%20the%20information%20flow%20of%20the%20network.%0A%20%20%20%20%20%20%20%20Forward%20hooks%20allow%20us%20to%20do%20exactly%20that.%0A%20%20%20%20%20%20%20%20%22%22%22%0A%20%20%20%20)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(decode_image%2C%20get_activation%2C%20image_path%2C%20model%2C%20preprocess_image)%3A%0A%20%20%20%20activations%20%3D%20%7B%7D%0A%0A%20%20%20%20model.avgpool.register_forward_hook(get_activation(%22avgpool%22%2C%20activations))%0A%20%20%20%20model.features%5B2%5D.register_forward_hook(get_activation(%22max_pool1%22%2C%20activations))%0A%20%20%20%20model.features%5B5%5D.register_forward_hook(get_activation(%22max_pool2%22%2C%20activations))%0A%20%20%20%20model.classifier%5B6%5D.register_forward_hook(get_activation(%22classification%22%2C%20activations))%0A%0A%20%20%20%20img%20%3D%20preprocess_image(decode_image(image_path))%0A%0A%20%20%20%20_%20%3D%20model.forward(img.unsqueeze(0))%0A%20%20%20%20return%20activations%2C%20img%0A%0A%0A%40app.cell%0Adef%20_(activations%2C%20image_selection%2C%20plt)%3A%0A%20%20%20%20def%20plot_activations(activations)%3A%0A%20%20%20%20%20%20%20%20fig%2C%20ax%20%3D%20plt.subplots(nrows%3D1%2C%20ncols%3D3%2C%20figsize%3D(12%2C%205))%0A%0A%20%20%20%20%20%20%20%20ax%5B0%5D.imshow(activations%5B%22max_pool1%22%5D.squeeze().reshape(64%2C%20729))%0A%20%20%20%20%20%20%20%20ax%5B0%5D.set_title(%22Max%20Pooling%201%22)%0A%20%20%20%20%20%20%20%20ax%5B1%5D.imshow(activations%5B%22max_pool2%22%5D.squeeze().reshape(192%2C%20169))%0A%20%20%20%20%20%20%20%20ax%5B1%5D.set_title(%22Max%20Pooling%202%22)%0A%20%20%20%20%20%20%20%20ax%5B2%5D.imshow(activations%5B%22avgpool%22%5D.squeeze().reshape(256%2C%2036))%0A%20%20%20%20%20%20%20%20ax%5B2%5D.set_title(%22Average%20Pooling%22)%0A%0A%20%20%20%20%20%20%20%20fig.suptitle(f%22Layer%20Activations%20for%20%7Bimage_selection.value%7D%22)%0A%0A%20%20%20%20%20%20%20%20plt.show()%0A%0A%20%20%20%20plot_activations(activations)%0A%20%20%20%20return%20(plot_activations%2C)%0A%0A%0A%40app.cell%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(r%22%22%22%23%23%20Helper%20Functions%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(AlexNet_Weights%2C%20json%2C%20v2)%3A%0A%20%20%20%20preprocess_image%20%3D%20v2.Compose(%5B%0A%20%20%20%20%20%20%20%20AlexNet_Weights.IMAGENET1K_V1.transforms()%0A%20%20%20%20%5D)%0A%0A%20%20%20%20def%20get_activation(name%3A%20str%2C%20activations)%3A%0A%20%20%20%20%20%20%23%20the%20hook%20signature%0A%20%20%20%20%20%20def%20hook(model%2C%20input%2C%20output)%3A%0A%20%20%20%20%20%20%20%20activations%5Bname%5D%20%3D%20output.detach()%0A%20%20%20%20%20%20return%20hook%0A%0A%20%20%20%20def%20get_imagenet_labels(path%3A%20str%20%3D%20%22data%2Fimagenet_labels.json%22)%20-%3E%20dict%5Bint%2C%20str%5D%3A%0A%20%20%20%20%20%20%20%20with%20open(path)%20as%20f%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20return%20json.load(f)%0A%0A%20%20%20%20labels%20%3D%20get_imagenet_labels()%0A%20%20%20%20return%20get_activation%2C%20get_imagenet_labels%2C%20labels%2C%20preprocess_image%0A%0A%0Aif%20__name__%20%3D%3D%20%22__main__%22%3A%0A%20%20%20%20app.run()%0A
ae0d1595f9b3d72c4a7b38bbd387cca38ddf97ed05cab4aa6568ceb27be67f74