Lei Zhou, Huidong Liu, Joseph Bae, Junjun He, Dimitris Samaras, Prateek Prasanna
Information processing in medical imaging : proceedings of the ... conference, vol. 13939, pp. 743-754.
DOI: 10.1007/978-3-031-34048-2_57. Published 2023-06-01 (Epub 2023-06-08).
Open-access PDF: https://www.ncbi.nlm.nih.gov/pmc/articles/PMC11056020/pdf/
Token Sparsification for Faster Medical Image Segmentation.
Can we use sparse tokens for dense prediction, e.g., segmentation? Although token sparsification has been applied to Vision Transformers (ViT) to accelerate classification, it is still unknown how to perform segmentation from sparse tokens. To this end, we reformulate segmentation as a sparse encoding → token completion → dense decoding (SCD) pipeline. We first empirically show that naïvely applying existing approaches from classification token pruning and masked image modeling (MIM) leads to failure and inefficient training caused by inappropriate sampling algorithms and the low quality of the restored dense features. In this paper, we propose Soft-topK Token Pruning (STP) and Multi-layer Token Assembly (MTA) to address these problems. In sparse encoding, STP predicts token importance scores with a lightweight sub-network and samples the topK tokens. The intractable topK gradients are approximated through a continuous perturbed score distribution. In token completion, MTA restores a full token sequence by assembling both sparse output tokens and pruned multi-layer intermediate ones. The last dense decoding stage is compatible with existing segmentation decoders, e.g., UNETR. Experiments show SCD pipelines equipped with STP and MTA are much faster than baselines without token pruning in both training (up to 120% higher throughput) and inference (up to 60.6% higher throughput) while maintaining segmentation quality. Code is available here: https://github.com/cvlab-stonybrook/TokenSparse-for-MedSeg.
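The two core ideas in the abstract can be illustrated with a minimal NumPy sketch. This is not the paper's actual implementation (STP trains a score sub-network end-to-end; the real code is at the GitHub link above): the function names, the Gaussian-perturbation averaging used to smooth top-K selection, and the scatter-based token completion below are illustrative assumptions only.

```python
import numpy as np

def perturbed_topk(scores, k, sigma=0.05, n_samples=200, rng=None):
    """Soft top-K membership via perturbed scores (sketch).

    Adds Gaussian noise to per-token importance scores, takes a hard
    top-K for each noisy draw, and averages the resulting 0/1 indicator
    vectors. The average is a smooth relaxation of top-K membership,
    which is the sense in which otherwise intractable top-K gradients
    become approximable in expectation.
    """
    rng = np.random.default_rng(rng)
    scores = np.asarray(scores, dtype=float)
    membership = np.zeros_like(scores)
    for _ in range(n_samples):
        noisy = scores + sigma * rng.standard_normal(scores.shape)
        topk_idx = np.argpartition(noisy, -k)[-k:]  # indices of k largest
        membership[topk_idx] += 1.0
    return membership / n_samples

def assemble_tokens(kept_tokens, pruned_tokens, kept_idx, pruned_idx, n_tokens):
    """Token completion (sketch): scatter kept encoder-output tokens and
    pruned intermediate-layer tokens back to their original positions,
    restoring a full-length sequence for a dense decoder."""
    full = np.zeros((n_tokens, kept_tokens.shape[1]))
    full[kept_idx] = kept_tokens      # sparse output tokens
    full[pruned_idx] = pruned_tokens  # pruned multi-layer intermediates
    return full

# With a small sigma relative to the score gaps, the soft membership is
# close to the hard top-K indicator; each draw keeps exactly k tokens,
# so the membership vector always sums to k.
scores = np.array([0.9, 0.1, 0.8, 0.2, 0.7])
soft = perturbed_topk(scores, k=3, sigma=0.01, rng=0)
```

Larger sigma spreads membership mass across borderline tokens, trading selection sharpness for smoother gradients; the paper's continuous perturbed score distribution plays this role during training.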