Sahil Sethi, David Chen, Thomas Statchen, Michael C Burkhart, Nipun Bhandari, Bashar Ramadan, Brett Beaulieu-Jones
Deep learning-based electrocardiogram (ECG) classification has shown impressive performance but clinical adoption has been slowed by the lack of transparent and faithful explanations. Post hoc methods such as saliency maps may fail to reflect a model's true decision process. Prototype-based reasoning offers a more transparent alternative by grounding decisions in similarity to learned representations of real ECG segments, enabling faithful, case-based explanations. We introduce ProtoECGNet, a prototype-based deep learning model for interpretable, multi-label ECG classification. ProtoECGNet employs a structured, multi-branch architecture that reflects clinical interpretation workflows: it integrates a 1D CNN with global prototypes for rhythm classification, a 2D CNN with time-localized prototypes for morphology-based reasoning, and a 2D CNN with global prototypes for diffuse abnormalities. Each branch is trained with a prototype loss designed for multi-label learning, combining clustering, separation, diversity, and a novel contrastive loss that encourages appropriate separation between prototypes of unrelated classes while allowing clustering for frequently co-occurring diagnoses. We evaluate ProtoECGNet on all 71 labels from the PTB-XL dataset, demonstrating competitive performance relative to state-of-the-art black-box models while providing structured, case-based explanations. To assess prototype quality, we conduct a structured clinician review of the final model's projected prototypes, finding that they are rated as representative and clear. ProtoECGNet shows that prototype learning can be effectively scaled to complex, multi-label time-series classification, offering a practical path toward transparent and trustworthy deep learning models for clinical decision support.
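A minimal sketch of the two ideas the abstract highlights: scoring each label by similarity to its closest learned prototype, and a contrastive penalty that pushes apart prototypes of rarely co-occurring classes while tolerating proximity for labels that often co-occur. All shapes, names, and the exact loss form below are illustrative assumptions written in PyTorch, not the paper's released code.

```python
# Minimal sketch of prototype-based multi-label scoring plus a co-occurrence-aware
# contrastive prototype penalty. Everything here is an illustrative assumption.
import torch
import torch.nn.functional as F

n_classes, protos_per_class, d = 4, 3, 64
prototypes = torch.randn(n_classes * protos_per_class, d, requires_grad=True)
proto_class = torch.arange(n_classes).repeat_interleave(protos_per_class)

def logits_from_prototypes(z):
    """z: (batch, d) latent ECG embeddings -> (batch, n_classes) logits."""
    sim = F.cosine_similarity(z.unsqueeze(1), prototypes.unsqueeze(0), dim=-1)  # (batch, P)
    # For each class, take the best-matching prototype (case-based evidence).
    per_class = torch.stack([sim[:, proto_class == c].max(dim=1).values
                             for c in range(n_classes)], dim=1)
    return per_class

def contrastive_prototype_loss(cooccur):
    """cooccur: (n_classes, n_classes) label co-occurrence rates in [0, 1].
    Penalize similarity between prototypes of classes that rarely co-occur."""
    psim = F.cosine_similarity(prototypes.unsqueeze(1), prototypes.unsqueeze(0), dim=-1)
    weight = 1.0 - cooccur[proto_class][:, proto_class]   # high weight = unrelated classes
    off_diag = 1.0 - torch.eye(len(prototypes))
    return (weight * off_diag * psim.clamp(min=0)).mean()

z = torch.randn(8, d)                                # stand-in for CNN features of 8 ECGs
y = (torch.rand(8, n_classes) > 0.7).float()         # multi-label targets
cooccur = (y.T @ y) / y.sum(0).clamp(min=1)
loss = F.binary_cross_entropy_with_logits(logits_from_prototypes(z), y) \
       + 0.1 * contrastive_prototype_loss(cooccur)
loss.backward()
```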
{"title":"ProtoECGNet: Case-Based Interpretable Deep Learning for Multi-Label ECG Classification with Contrastive Learning.","authors":"Sahil Sethi, David Chen, Thomas Statchen, Michael C Burkhart, Nipun Bhandari, Bashar Ramadan, Brett Beaulieu-Jones","doi":"","DOIUrl":"","url":null,"abstract":"<p><p>Deep learning-based electrocardiogram (ECG) classification has shown impressive performance but clinical adoption has been slowed by the lack of transparent and faithful explanations. Post hoc methods such as saliency maps may fail to reflect a model's true decision process. Prototype-based reasoning offers a more transparent alternative by grounding decisions in similarity to learned representations of real ECG segments-enabling faithful, case-based explanations. We introduce ProtoECGNet, a prototype-based deep learning model for interpretable, multi-label ECG classification. ProtoECGNet employs a structured, multi-branch architecture that reflects clinical interpretation workflows: it integrates a 1D CNN with global prototypes for rhythm classification, a 2D CNN with time-localized prototypes for morphology-based reasoning, and a 2D CNN with global prototypes for diffuse abnormalities. Each branch is trained with a prototype loss designed for multi-label learning, combining clustering, separation, diversity, and a novel contrastive loss that encourages appropriate separation between prototypes of unrelated classes while allowing clustering for frequently co-occurring diagnoses. We evaluate ProtoECGNet on all 71 labels from the PTB-XL dataset, demonstrating competitive performance relative to state-of-the-art black-box models while providing structured, case-based explanations. To assess prototype quality, we conduct a structured clinician review of the final model's projected prototypes, finding that they are rated as representative and clear. ProtoECGNet shows that prototype learning can be effectively scaled to complex, multi-label time-series classification, offering a practical path toward transparent and trustworthy deep learning models for clinical decision support.</p>","PeriodicalId":74504,"journal":{"name":"Proceedings of machine learning research","volume":"298 ","pages":""},"PeriodicalIF":0.0,"publicationDate":"2025-08-01","publicationTypes":"Journal Article","fieldsOfStudy":null,"isOpenAccess":false,"openAccessPdf":"https://www.ncbi.nlm.nih.gov/pmc/articles/PMC12700622/pdf/","citationCount":null,"resultStr":null,"platform":"Semanticscholar","paperid":"145758662","PeriodicalName":null,"FirstCategoryId":null,"ListUrlMain":null,"RegionNum":0,"RegionCategory":"","ArticlePicture":[],"TitleCN":null,"AbstractTextCN":null,"PMCID":"OA","EPubDate":null,"PubModel":null,"JCR":null,"JCRName":null,"Score":null,"Total":0}
Yuankang Zhao, Matthew M Engelhard
The Hawkes process (HP) is commonly used to model event sequences with self-reinforcing dynamics, including electronic health records (EHRs). Traditional HPs capture self-reinforcement via parametric impact functions that can be inspected to understand how each event modulates the intensity of others. Neural network-based HPs offer greater flexibility, resulting in improved fit and prediction performance, but at the cost of interpretability, which is often critical in healthcare. In this work, we aim to understand and improve upon this tradeoff. We propose a novel HP formulation in which impact functions are modeled by defining a flexible impact kernel, instantiated as a neural network, in event embedding space, which allows us to model large-scale event sequences with many event types. This approach is more flexible than traditional HPs yet more interpretable than other neural network approaches, and allows us to explicitly trade flexibility for interpretability by adding transformer encoder layers to further contextualize the event embeddings. Results show that our method accurately recovers impact functions in simulations, achieves competitive performance on the MIMIC-IV procedure dataset, and yields clinically meaningful interpretations on a Duke-EHR pediatric diagnosis dataset even without transformer layers. This suggests that our flexible impact kernel is often sufficient to capture self-reinforcing dynamics in EHRs and other data effectively, implying that interpretability can be maintained without loss of performance.
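As a hedged illustration of the central modeling idea, the following PyTorch fragment computes a Hawkes intensity whose impact function is a small neural kernel over event-type embeddings and elapsed time; the layer sizes, softplus link, and kernel form are assumptions rather than the authors' implementation.

```python
# Illustrative sketch of a Hawkes intensity with a neural impact kernel defined
# in event-embedding space. Architecture details are assumptions.
import torch
import torch.nn as nn

class EmbeddedKernelHawkes(nn.Module):
    def __init__(self, n_types, d=16):
        super().__init__()
        self.emb = nn.Embedding(n_types, d)
        self.mu = nn.Parameter(torch.zeros(n_types))          # baseline intensities
        self.kernel = nn.Sequential(                          # impact kernel g(e_i, e_k, dt)
            nn.Linear(2 * d + 1, 32), nn.Tanh(), nn.Linear(32, 1))

    def intensity(self, times, types, t, target_type):
        """lambda_{target_type}(t) given the history of events before time t."""
        mask = times < t
        e_hist = self.emb(types[mask])                                        # (H, d)
        e_tgt = self.emb(torch.tensor([target_type])).expand(e_hist.shape[0], -1)
        dt = (t - times[mask]).unsqueeze(-1)                                  # elapsed times
        impact = self.kernel(torch.cat([e_hist, e_tgt, dt], dim=-1)).sum()
        return torch.nn.functional.softplus(self.mu[target_type] + impact)

model = EmbeddedKernelHawkes(n_types=5)
times = torch.tensor([0.3, 1.1, 2.4])      # toy event history
types = torch.tensor([0, 3, 1])
print(model.intensity(times, types, t=torch.tensor(3.0), target_type=2))
```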
{"title":"Balancing Interpretability and Flexibility in Modeling Diagnostic Trajectories with an Embedded Neural Hawkes Process Model.","authors":"Yuankang Zhao, Matthew M Engelhard","doi":"","DOIUrl":"","url":null,"abstract":"<p><p>The Hawkes process (HP) is commonly used to model event sequences with self-reinforcing dynamics, including electronic health records (EHRs). Traditional HPs capture self-reinforcement via parametric impact functions that can be inspected to understand how each event modulates the intensity of others. Neural network-based HPs offer greater flexibility, resulting in improved fit and prediction performance, but at the cost of interpretability, which is often critical in healthcare. In this work, we aim to understand and improve upon this tradeoff. We propose a novel HP formulation in which impact functions are modeled by defining a flexible impact kernel, instantiated as a neural network, in event embedding space, which allows us to model large-scale event sequences with many event types. This approach is more flexible than traditional HPs yet more interpretable than other neural network approaches, and allows us to explicitly trade flexibility for interpretability by adding transformer encoder layers to further contextualize the event embeddings. Results show that our method accurately recovers impact functions in simulations, achieves competitive performance on MIMIC-IV procedure dataset, and gains clinically meaningful interpretation on Duke-EHR with children diagnosis dataset even without transformer layers. This suggests that our flexible impact kernel is often sufficient to capture self-reinforcing dynamics in EHRs and other data effectively, implying that interpretability can be maintained without loss of performance.</p>","PeriodicalId":74504,"journal":{"name":"Proceedings of machine learning research","volume":"298 ","pages":""},"PeriodicalIF":0.0,"publicationDate":"2025-08-01","publicationTypes":"Journal Article","fieldsOfStudy":null,"isOpenAccess":false,"openAccessPdf":"https://www.ncbi.nlm.nih.gov/pmc/articles/PMC12646569/pdf/","citationCount":null,"resultStr":null,"platform":"Semanticscholar","paperid":"145643662","PeriodicalName":null,"FirstCategoryId":null,"ListUrlMain":null,"RegionNum":0,"RegionCategory":"","ArticlePicture":[],"TitleCN":null,"AbstractTextCN":null,"PMCID":"OA","EPubDate":null,"PubModel":null,"JCR":null,"JCRName":null,"Score":null,"Total":0}
Guilherme Seidyo Imai Aldeia, Daniel S Herman, William G La Cava
Large language models (LLMs) have demonstrated remarkable capabilities for medical question answering and programming, but their potential for generating interpretable computable phenotypes (CPs) is under-explored. In this work, we investigate whether LLMs can generate accurate and concise CPs for six clinical phenotypes of varying complexity, which could be leveraged to enable scalable clinical decision support to improve care for patients with hypertension. In addition to evaluating zero-shot performance, we propose and test a synthesize, execute, debug, instruct strategy that uses LLMs to generate and iteratively refine CPs using data-driven feedback. Our results show that LLMs, coupled with iterative learning, can generate interpretable and reasonably accurate programs that approach the performance of state-of-the-art ML methods while requiring significantly fewer training examples.
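The sketch below illustrates one plausible synthesize-execute-debug-instruct loop: generate a candidate phenotype program, execute it on labeled patients, and feed the misclassified cases back as instructions. The `ask_llm` stub, prompt wording, and stopping rule are hypothetical placeholders, not the paper's pipeline.

```python
# Hedged sketch of an iterative phenotype-refinement loop. `ask_llm` is a stand-in
# for any chat-completion client; plug in your own.
def ask_llm(prompt: str) -> str:
    raise NotImplementedError("plug in your LLM client here")

def evaluate(code: str, patients: list[dict], labels: list[int]) -> tuple[float, list]:
    """Execute the candidate phenotype and collect misclassified examples."""
    scope = {}
    exec(code, scope)   # expects the LLM to define `phenotype(patient) -> bool`
    preds = [bool(scope["phenotype"](p)) for p in patients]
    errors = [(p, y, yhat) for p, y, yhat in zip(patients, labels, preds) if bool(y) != yhat]
    accuracy = 1.0 - len(errors) / len(labels)
    return accuracy, errors

def iterative_phenotype(task: str, patients, labels, rounds: int = 5) -> str:
    code = ask_llm(f"Write a Python function `phenotype(patient)` that returns True if {task}.")
    for _ in range(rounds):   # execute, then instruct the model with its own failures
        accuracy, errors = evaluate(code, patients, labels)
        if not errors:
            break
        feedback = "\n".join(f"patient={p} expected={y} got={yhat}" for p, y, yhat in errors[:5])
        code = ask_llm(f"Your phenotype code:\n{code}\nMisclassified cases:\n{feedback}\nRevise the function.")
    return code
```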
{"title":"Iterative Learning of Computable Phenotypes for Treatment Resistant Hypertension using Large Language Models.","authors":"Guilherme Seidyo Imai Aldeia, Daniel S Herman, William G La Cava","doi":"","DOIUrl":"","url":null,"abstract":"<p><p>Large language models (LLMs) have demonstrated remarkable capabilities for medical question answering and programming, but their potential for generating interpretable computable phenotypes (CPs) is under-explored. In this work, we investigate whether LLMs can generate accurate and concise CPs for six clinical phenotypes of varying complexity, which could be leveraged to enable scalable clinical decision support to improve care for patients with hypertension. In addition to evaluating zero-short performance, we propose and test a <i>synthesize, execute, debug, instruct</i> strategy that uses LLMs to generate and iteratively refine CPs using data-driven feedback. Our results show that LLMs, coupled with iterative learning, can generate interpretable and reasonably accurate programs that approach the performance of state-of-the-art ML methods while requiring significantly fewer training examples.</p>","PeriodicalId":74504,"journal":{"name":"Proceedings of machine learning research","volume":"298 ","pages":""},"PeriodicalIF":0.0,"publicationDate":"2025-08-01","publicationTypes":"Journal Article","fieldsOfStudy":null,"isOpenAccess":false,"openAccessPdf":"https://www.ncbi.nlm.nih.gov/pmc/articles/PMC12755843/pdf/","citationCount":null,"resultStr":null,"platform":"Semanticscholar","paperid":"145890622","PeriodicalName":null,"FirstCategoryId":null,"ListUrlMain":null,"RegionNum":0,"RegionCategory":"","ArticlePicture":[],"TitleCN":null,"AbstractTextCN":null,"PMCID":"OA","EPubDate":null,"PubModel":null,"JCR":null,"JCRName":null,"Score":null,"Total":0}
Minghui Sun, Matthew M Engelhard, Benjamin A Goldstein
Risk assessments for a pediatric population are often conducted across multiple stages. For example, clinicians may evaluate risks prenatally, at birth, and during WellChild visits. While predictions at later stages typically achieve higher accuracy, it is clinically desirable to make reliable risk assessments as early as possible. Therefore, this study focuses on enhancing prediction performance in early-stage risk assessments. Our solution, Borrowing From the Future (BFF), is a contrastive multi-modal framework that treats each time window as a distinct modality. In BFF, a model is trained on all data available across time while conducting risk assessment using only the information available up to the current stage. This contrastive framework allows the model to "borrow" informative signals from later stages (e.g., WellChild visits) to implicitly supervise the learning at earlier stages (e.g., prenatal/birth stages). We validate BFF on two real-world pediatric outcome prediction tasks, demonstrating consistent improvements in early risk assessment. The code is at https://github.com/scotsun/bff.
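A minimal sketch of the contrastive mechanism: treat each patient's early-stage and later-stage embeddings as two views and align them with an InfoNCE-style loss, so later windows implicitly supervise the early encoder. The encoders, temperature, and exact objective below are illustrative assumptions.

```python
# Minimal InfoNCE-style sketch of "borrowing from the future": a patient's early
# embedding should match their own later-stage embedding more than other patients'.
import torch
import torch.nn.functional as F

def bff_contrastive_loss(z_early, z_late, temperature=0.1):
    """z_early, z_late: (batch, d) embeddings of the same patients at two stages."""
    z_early = F.normalize(z_early, dim=-1)
    z_late = F.normalize(z_late, dim=-1)
    logits = z_early @ z_late.T / temperature        # (batch, batch) similarity matrix
    targets = torch.arange(z_early.shape[0])         # matching patient = positive pair
    return F.cross_entropy(logits, targets)

early_encoder = torch.nn.Linear(20, 32)              # stand-ins for the stage encoders
late_encoder = torch.nn.Linear(50, 32)
x_early, x_late = torch.randn(16, 20), torch.randn(16, 50)
loss = bff_contrastive_loss(early_encoder(x_early), late_encoder(x_late))
loss.backward()
```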
{"title":"Borrowing From the Future: Enhancing Early Risk Assessment through Contrastive Learning.","authors":"Minghui Sun, Matthew M Engelhard, Benjamin A Goldstein","doi":"","DOIUrl":"","url":null,"abstract":"<p><p>Risk assessments for a pediatric population are often conducted across multiple stages. For example, clinicians may evaluate risks prenatally, at birth, and during WellChild visits. While predictions at later stages typically achieve higher accuracy, it is clinically desirable to make reliable risk assessments as early as possible. Therefore, this study focuses on enhancing prediction performance in early-stage risk assessments. Our solution, <b>Borrowing From the Future (BFF)</b>, is a contrastive multi-modal framework that treats each time window as a distinct modality. In BFF, a model is trained on all available data throughout the time while conduct risk assessment using the up-to-time information. This contrastive framework allows the model to \"borrow\" informative signals from later stages (e.g., WellChild visits) to implicitly supervise the learning at earlier stages (e.g., prenatal/birth stages). We validate BFF on two real-world pediatric outcome prediction tasks, demonstrating consistent improvements in early risk assessment. The code is at https://github.com/scotsun/bff.</p>","PeriodicalId":74504,"journal":{"name":"Proceedings of machine learning research","volume":"298 ","pages":""},"PeriodicalIF":0.0,"publicationDate":"2025-08-01","publicationTypes":"Journal Article","fieldsOfStudy":null,"isOpenAccess":false,"openAccessPdf":"https://www.ncbi.nlm.nih.gov/pmc/articles/PMC12646567/pdf/","citationCount":null,"resultStr":null,"platform":"Semanticscholar","paperid":"145642974","PeriodicalName":null,"FirstCategoryId":null,"ListUrlMain":null,"RegionNum":0,"RegionCategory":"","ArticlePicture":[],"TitleCN":null,"AbstractTextCN":null,"PMCID":"OA","EPubDate":null,"PubModel":null,"JCR":null,"JCRName":null,"Score":null,"Total":0}
Harvineet Singh, Fan Xia, Alexej Gossmann, Andrew Chuang, Julian C Hong, Jean Feng
Machine learning (ML) models frequently experience performance degradation when deployed in new contexts. Such degradation is rarely uniform: some subgroups may suffer large performance decay while others may not. Understanding where and how large differences in performance arise is critical for designing targeted corrective actions that mitigate decay for the most affected subgroups while minimizing any unintended effects. Current approaches do not provide such detailed insight, as they either (i) explain how average performance shifts arise or (ii) identify adversely affected subgroups without insight into how this occurred. To this end, we introduce a Subgroup-scanning Hierarchical Inference Framework for performance drifT (SHIFT). SHIFT first asks "Is there any subgroup with unacceptably large performance decay due to covariate/outcome shifts?" (Where?) and, if so, dives deeper to ask "Can we explain this using more detailed variable(subset)-specific shifts?" (How?). In real-world experiments, we find that SHIFT identifies interpretable subgroups affected by performance decay, and suggests targeted actions that effectively mitigate the decay.
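To make the "Where?" step concrete, here is a deliberately simplified subgroup scan that flags large AUC decay between a source and a target dataset; the real SHIFT framework adds hierarchical hypothesis testing and variable-specific shift decompositions, and the column names below are made up.

```python
# Simplified sketch of a subgroup scan for performance decay. Not the SHIFT
# algorithm itself; it only illustrates the "which subgroups decayed?" question.
import pandas as pd
from sklearn.metrics import roc_auc_score

def scan_subgroups(source: pd.DataFrame, target: pd.DataFrame,
                   group_cols: list[str], min_decay: float = 0.05) -> pd.DataFrame:
    """source/target need columns 'y' (label) and 'score' (model prediction)."""
    flagged = []
    for col in group_cols:
        for value in sorted(set(source[col]) & set(target[col])):
            src = source[source[col] == value]
            tgt = target[target[col] == value]
            if src["y"].nunique() < 2 or tgt["y"].nunique() < 2:
                continue   # AUC is undefined for one-class subgroups
            decay = roc_auc_score(src["y"], src["score"]) - roc_auc_score(tgt["y"], tgt["score"])
            if decay >= min_decay:
                flagged.append({"subgroup": f"{col}={value}", "auc_decay": round(decay, 3)})
    if not flagged:
        return pd.DataFrame(columns=["subgroup", "auc_decay"])
    return pd.DataFrame(flagged).sort_values("auc_decay", ascending=False)
```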
{"title":"\"Who experiences large model decay and why?\" A Hierarchical Framework for Diagnosing Heterogeneous Performance Drift.","authors":"Harvineet Singh, Fan Xia, Alexej Gossmann, Andrew Chuang, Julian C Hong, Jean Feng","doi":"","DOIUrl":"","url":null,"abstract":"<p><p>Machine learning (ML) models frequently experience performance degradation when deployed in new contexts. Such degradation is rarely uniform: some subgroups may suffer large performance decay while others may not. Understanding where and how large differences in performance arise is critical for designing <i>targeted</i> corrective actions that mitigate decay for the most affected subgroups while minimizing any unintended effects. Current approaches do not provide such detailed insight, as they either (i) explain how <i>average</i> performance shifts arise or (ii) identify adversely affected subgroups without insight into how this occurred. To this end, we introduce a <b>S</b>ubgroup-scanning <b>H</b>ierarchical <b>I</b>nference <b>F</b>ramework for performance drif<b>T</b> (SHIFT). SHIFT first asks \"Is there any subgroup with unacceptably large performance decay due to covariate/outcome shifts?\" (<i>Where?</i>) and, if so, dives deeper to ask \"Can we explain this using more detailed variable(subset)-specific shifts?\" (<i>How?</i>). In real-world experiments, we find that SHIFT identifies interpretable subgroups affected by performance decay, and suggests targeted actions that effectively mitigate the decay.</p>","PeriodicalId":74504,"journal":{"name":"Proceedings of machine learning research","volume":"267 ","pages":"55757-55787"},"PeriodicalIF":0.0,"publicationDate":"2025-07-01","publicationTypes":"Journal Article","fieldsOfStudy":null,"isOpenAccess":false,"openAccessPdf":"https://www.ncbi.nlm.nih.gov/pmc/articles/PMC12747154/pdf/","citationCount":null,"resultStr":null,"platform":"Semanticscholar","paperid":"145866889","PeriodicalName":null,"FirstCategoryId":null,"ListUrlMain":null,"RegionNum":0,"RegionCategory":"","ArticlePicture":[],"TitleCN":null,"AbstractTextCN":null,"PMCID":"OA","EPubDate":null,"PubModel":null,"JCR":null,"JCRName":null,"Score":null,"Total":0}
Mohammad Hosseini, Maryam M Shanechi
High-dimensional imaging of neural activity, such as widefield calcium and functional ultrasound imaging, provides a rich source of information for understanding the relationship between brain activity and behavior. Accurately modeling neural dynamics in these modalities is crucial for understanding this relationship but is hindered by the high dimensionality, complex spatiotemporal dependencies, and prevalent behaviorally irrelevant dynamics in these modalities. Existing dynamical models often employ preprocessing steps to obtain low-dimensional representations from neural image modalities. However, this process can discard behaviorally relevant information and miss spatiotemporal structure. We propose SBIND, a novel data-driven deep learning framework to model spatiotemporal dependencies in neural images and disentangle their behaviorally relevant dynamics from other neural dynamics. We validate SBIND on widefield imaging datasets, and show its extension to functional ultrasound imaging, a recent modality whose dynamical modeling has largely remained unexplored. We find that our model effectively identifies both local and long-range spatial dependencies across the brain while also dissociating behaviorally relevant neural dynamics. Doing so, SBIND outperforms existing models in neural-behavioral prediction. Overall, SBIND provides a versatile tool for investigating the neural mechanisms underlying behavior using imaging modalities.
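A speculative sketch of the dissociation idea: encode each frame, split the latent state into a behaviorally relevant part and a residual part, and train the relevant part to predict behavior while both parts reconstruct the neural data. SBIND's actual architecture and dynamics model are more elaborate; every name and dimension below is a placeholder.

```python
# Speculative sketch of dissociating behaviorally relevant latents from residual
# neural dynamics. Not SBIND's architecture; all sizes are placeholders.
import torch
import torch.nn as nn

class DissociatedEncoder(nn.Module):
    def __init__(self, d_beh=8, d_res=8):
        super().__init__()
        self.conv = nn.Sequential(nn.Conv2d(1, 8, 5, stride=2), nn.ReLU(),
                                  nn.AdaptiveAvgPool2d(4), nn.Flatten(),
                                  nn.Linear(8 * 16, d_beh + d_res))
        self.d_beh = d_beh
        self.behavior_head = nn.Linear(d_beh, 2)              # e.g. 2D movement readout
        self.recon_head = nn.Linear(d_beh + d_res, 32 * 32)   # neural reconstruction

    def forward(self, frames):                                # frames: (batch, 1, 32, 32)
        z = self.conv(frames)
        z_beh = z[:, :self.d_beh]                             # behaviorally relevant subspace
        return self.behavior_head(z_beh), self.recon_head(z).view(-1, 1, 32, 32)

model = DissociatedEncoder()
frames, behavior = torch.randn(4, 1, 32, 32), torch.randn(4, 2)
beh_pred, recon = model(frames)
loss = nn.functional.mse_loss(beh_pred, behavior) + nn.functional.mse_loss(recon, frames)
loss.backward()
```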
{"title":"Dynamical Modeling of Behaviorally Relevant Spatiotemporal Patterns in Neural Imaging Data.","authors":"Mohammad Hosseini, Maryam M Shanechi","doi":"","DOIUrl":"","url":null,"abstract":"<p><p>High-dimensional imaging of neural activity, such as widefield calcium and functional ultrasound imaging, provide a rich source of information for understanding the relationship between brain activity and behavior. Accurately modeling neural dynamics in these modalities is crucial for understanding this relationship but is hindered by the high-dimensionality, complex spatiotemporal dependencies, and prevalent behaviorally irrelevant dynamics in these modalities. Existing dynamical models often employ preprocessing steps to obtain low-dimensional representations from neural image modalities. However, this process can discard behaviorally relevant information and miss spatiotemporal structure. We propose SBIND, a novel data-driven deep learning framework to model spatiotemporal dependencies in neural images and disentangle their behaviorally relevant dynamics from other neural dynamics. We validate SBIND on widefield imaging datasets, and show its extension to functional ultrasound imaging, a recent modality whose dynamical modeling has largely remained unexplored. We find that our model effectively identifies both local and long-range spatial dependencies across the brain while also dissociating behaviorally relevant neural dynamics. Doing so, SBIND outperforms existing models in neural-behavioral prediction. Overall, SBIND provides a versatile tool for investigating the neural mechanisms underlying behavior using imaging modalities.</p>","PeriodicalId":74504,"journal":{"name":"Proceedings of machine learning research","volume":"267 ","pages":"23846-23872"},"PeriodicalIF":0.0,"publicationDate":"2025-07-01","publicationTypes":"Journal Article","fieldsOfStudy":null,"isOpenAccess":false,"openAccessPdf":"https://www.ncbi.nlm.nih.gov/pmc/articles/PMC12662753/pdf/","citationCount":null,"resultStr":null,"platform":"Semanticscholar","paperid":"145650400","PeriodicalName":null,"FirstCategoryId":null,"ListUrlMain":null,"RegionNum":0,"RegionCategory":"","ArticlePicture":[],"TitleCN":null,"AbstractTextCN":null,"PMCID":"OA","EPubDate":null,"PubModel":null,"JCR":null,"JCRName":null,"Score":null,"Total":0}
Halil Alperen Gozeten, M Emrullah Ildiz, Xuechen Zhang, Mahdi Soltanolkotabi, Marco Mondelli, Samet Oymak
Test-time training (TTT) methods explicitly update the weights of a model to adapt to the specific test instance, and they have found success in a variety of settings, including most recently language modeling and reasoning. To demystify this success, we investigate a gradient-based TTT algorithm for in-context learning, where we train a transformer model on the in-context demonstrations provided in the test prompt. Specifically, we provide a comprehensive theoretical characterization of linear transformers when the update rule is a single gradient step. Our theory (i) delineates the role of alignment between pretraining distribution and target task, (ii) demystifies how TTT can alleviate distribution shift, and (iii) quantifies the sample complexity of TTT including how it can significantly reduce the eventual sample size required for in-context learning. As our empirical contribution, we study the benefits of TTT for TabPFN, a tabular foundation model. In line with our theory, we demonstrate that TTT significantly reduces the required sample size for tabular classification (3 to 5 times fewer) unlocking substantial inference efficiency with a negligible training cost.
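The core procedure is easy to state in code: take one gradient step on the prompt's labeled demonstrations, predict the query with the adapted weights, then discard the update. The sketch below is a generic single-step TTT loop in PyTorch with an assumed model, loss, and step size, not the paper's transformer setup.

```python
# Minimal single-gradient-step test-time training for in-context predictions.
# The model, loss, and learning rate are illustrative assumptions.
import copy
import torch

def ttt_predict(model, x_demos, y_demos, x_query, lr=1e-2):
    adapted = copy.deepcopy(model)                      # keep the original weights untouched
    loss = torch.nn.functional.mse_loss(adapted(x_demos), y_demos)  # fit the demonstrations
    loss.backward()
    with torch.no_grad():
        for p in adapted.parameters():                  # one explicit gradient step
            p -= lr * p.grad
        return adapted(x_query)                         # predict with the adapted weights

model = torch.nn.Linear(8, 1)
x_demos, y_demos = torch.randn(32, 8), torch.randn(32, 1)
print(ttt_predict(model, x_demos, y_demos, torch.randn(1, 8)))
```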
{"title":"Test-Time Training Provably Improves Transformers as In-context Learners.","authors":"Halil Alperen Gozeten, M Emrullah Ildiz, Xuechen Zhang, Mahdi Soltanolkotabi, Marco Mondelli, Samet Oymak","doi":"","DOIUrl":"","url":null,"abstract":"<p><p>Test-time training (TTT) methods explicitly update the weights of a model to adapt to the specific test instance, and they have found success in a variety of settings, including most recently language modeling and reasoning. To demystify this success, we investigate a gradient-based TTT algorithm for in-context learning, where we train a transformer model on the in-context demonstrations provided in the test prompt. Specifically, we provide a comprehensive theoretical characterization of linear transformers when the update rule is a single gradient step. Our theory <i>(i)</i> delineates the role of alignment between pretraining distribution and target task, <i>(ii)</i> demystifies how TTT can alleviate distribution shift, and <i>(iii)</i> quantifies the sample complexity of TTT including how it can significantly reduce the eventual sample size required for in-context learning. As our empirical contribution, we study the benefits of TTT for TabPFN, a tabular foundation model. In line with our theory, we demonstrate that TTT significantly reduces the required sample size for tabular classification (3 to 5 times fewer) unlocking substantial inference efficiency with a negligible training cost.</p>","PeriodicalId":74504,"journal":{"name":"Proceedings of machine learning research","volume":"267 ","pages":"20266-20295"},"PeriodicalIF":0.0,"publicationDate":"2025-07-01","publicationTypes":"Journal Article","fieldsOfStudy":null,"isOpenAccess":false,"openAccessPdf":"https://www.ncbi.nlm.nih.gov/pmc/articles/PMC12662752/pdf/","citationCount":null,"resultStr":null,"platform":"Semanticscholar","paperid":"145650398","PeriodicalName":null,"FirstCategoryId":null,"ListUrlMain":null,"RegionNum":0,"RegionCategory":"","ArticlePicture":[],"TitleCN":null,"AbstractTextCN":null,"PMCID":"OA","EPubDate":null,"PubModel":null,"JCR":null,"JCRName":null,"Score":null,"Total":0}
Deming Sheng, Ricardo Henao
Probabilistic survival analysis models seek to estimate the distribution of the future occurrence (time) of an event given a set of covariates. In recent years, these models have preferred nonparametric specifications that avoid directly estimating survival distributions via discretization. Specifically, they estimate the probability of an individual event at fixed times or the time of an event at fixed probabilities (quantiles), using supervised learning. Borrowing ideas from the quantile regression literature, we propose a parametric survival analysis method based on the Asymmetric Laplace Distribution (ALD). This distribution allows for closed-form calculation of popular event summaries such as mean, median, mode, variation, and quantiles. The model is optimized by maximum likelihood to learn, at the individual level, the parameters (location, scale, and asymmetry) of the ALD distribution. Extensive results on synthetic and real-world data demonstrate that the proposed method outperforms parametric and nonparametric approaches in terms of accuracy, discrimination and calibration.
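For concreteness, here is one common (location, scale, asymmetry) parameterization of the ALD with its log-density, CDF, censored negative log-likelihood, and closed-form quantiles; the network that maps covariates to the three parameters is omitted, and this need not match the paper's exact parameterization.

```python
# Hedged sketch: ALD log-density, CDF, right-censored negative log-likelihood,
# and closed-form quantiles in one standard parameterization (m, lam, kappa).
import numpy as np

def ald_logpdf(t, m, lam, kappa):
    s = np.sign(t - m)
    return np.log(lam / (kappa + 1.0 / kappa)) - (t - m) * lam * s * kappa**s

def ald_cdf(t, m, lam, kappa):
    left = kappa**2 / (1 + kappa**2) * np.exp((lam / kappa) * (t - m))
    right = 1 - np.exp(-lam * kappa * (t - m)) / (1 + kappa**2)
    return np.where(t <= m, left, right)

def censored_neg_loglik(t, event, m, lam, kappa):
    """event=1: observed time -> log f(t); event=0: right-censored -> log S(t)."""
    log_f = ald_logpdf(t, m, lam, kappa)
    log_s = np.log1p(-ald_cdf(t, m, lam, kappa))
    return -np.sum(np.where(event == 1, log_f, log_s))

def ald_quantile(p, m, lam, kappa):
    """Closed-form quantiles, e.g. p=0.5 gives the predicted median survival time."""
    split = kappa**2 / (1 + kappa**2)
    lo = m + (kappa / lam) * np.log(p * (1 + kappa**2) / kappa**2)
    hi = m - np.log((1 - p) * (1 + kappa**2)) / (lam * kappa)
    return np.where(p <= split, lo, hi)

t = np.array([2.0, 5.0, 7.5]); event = np.array([1, 0, 1])
print(censored_neg_loglik(t, event, m=4.0, lam=0.8, kappa=1.2))
print(ald_quantile(0.5, m=4.0, lam=0.8, kappa=1.2))
```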
{"title":"Learning Survival Distributions with the Asymmetric Laplace Distribution.","authors":"Deming Sheng, Ricardo Henao","doi":"","DOIUrl":"","url":null,"abstract":"<p><p>Probabilistic survival analysis models seek to estimate the distribution of the future occurrence (time) of an event given a set of covariates. In recent years, these models have preferred nonparametric specifications that avoid directly estimating survival distributions via discretization. Specifically, they estimate the probability of an individual event at fixed times or the time of an event at fixed probabilities (quantiles), using supervised learning. Borrowing ideas from the quantile regression literature, we propose a parametric survival analysis method based on the Asymmetric Laplace Distribution (ALD). This distribution allows for closed-form calculation of popular event summaries such as mean, median, mode, variation, and quantiles. The model is optimized by maximum likelihood to learn, at the individual level, the parameters (location, scale, and asymmetry) of the ALD distribution. Extensive results on synthetic and real-world data demonstrate that the proposed method outperforms parametric and nonparametric approaches in terms of accuracy, discrimination and calibration.</p>","PeriodicalId":74504,"journal":{"name":"Proceedings of machine learning research","volume":"267 ","pages":"54772-54809"},"PeriodicalIF":0.0,"publicationDate":"2025-07-01","publicationTypes":"Journal Article","fieldsOfStudy":null,"isOpenAccess":false,"openAccessPdf":"https://www.ncbi.nlm.nih.gov/pmc/articles/PMC12669900/pdf/","citationCount":null,"resultStr":null,"platform":"Semanticscholar","paperid":"145673129","PeriodicalName":null,"FirstCategoryId":null,"ListUrlMain":null,"RegionNum":0,"RegionCategory":"","ArticlePicture":[],"TitleCN":null,"AbstractTextCN":null,"PMCID":"OA","EPubDate":null,"PubModel":null,"JCR":null,"JCRName":null,"Score":null,"Total":0}
Talal Widatalla, Richard W Shuai, Brian L Hie, Po-Ssu Huang
Leading deep learning-based methods for fixed-backbone protein sequence design do not model protein sidechain conformation during sequence generation despite the large role the three-dimensional arrangement of sidechain atoms plays in protein conformation, stability, and overall protein function. Instead, these models implicitly reason about crucial sidechain interactions based on backbone geometry and known amino acid sequence labels. To address this, we present FAMPNN (Full-Atom MPNN), a sequence design method that explicitly models both sequence identity and sidechain conformation for each residue, where the per-token distribution of a residue's discrete amino acid identity and its continuous sidechain conformation are learned with a combined categorical cross-entropy and diffusion loss objective. We demonstrate that learning these distributions jointly is a highly synergistic task that improves sequence recovery while achieving state-of-the-art sidechain packing. Furthermore, benefits from full-atom modeling generalize from sequence recovery to practical protein design applications, such as zero-shot prediction of experimental binding and stability measurements.
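A rough sketch of the kind of joint objective described: categorical cross-entropy on per-residue amino acid identity plus a denoising regression loss on continuous sidechain chi angles. The feature dimensions, toy noise schedule, and equal weighting below are assumptions, not FAMPNN's implementation.

```python
# Rough sketch of a combined sequence cross-entropy + sidechain denoising loss.
# Dimensions, heads, and the noising schedule are toy placeholders.
import torch
import torch.nn.functional as F

n_res, n_aa, n_chi, d = 10, 20, 4, 64
feats = torch.randn(n_res, d)                          # stand-in for backbone/sequence features
aa_true = torch.randint(0, n_aa, (n_res,))
chi_true = torch.rand(n_res, n_chi) * 2 * torch.pi - torch.pi

seq_head = torch.nn.Linear(d, n_aa)                    # predicts residue identity
denoise_head = torch.nn.Linear(d + n_chi + 1, n_chi)   # predicts the added chi-angle noise

t = torch.rand(n_res, 1)                               # diffusion time, one per residue
noise = torch.randn(n_res, n_chi)
chi_noisy = torch.sqrt(1 - t) * chi_true + torch.sqrt(t) * noise   # toy noising schedule

ce_loss = F.cross_entropy(seq_head(feats), aa_true)
noise_pred = denoise_head(torch.cat([feats, chi_noisy, t], dim=-1))
diffusion_loss = F.mse_loss(noise_pred, noise)
loss = ce_loss + diffusion_loss
loss.backward()
```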
{"title":"Sidechain conditioning and modeling for full-atom protein sequence design with FAMPNN.","authors":"Talal Widatalla, Richard W Shuai, Brian L Hie, Po-Ssu Huang","doi":"","DOIUrl":"","url":null,"abstract":"<p><p>Leading deep learning-based methods for fixed-backbone protein sequence design do not model protein sidechain conformation during sequence generation despite the large role the three-dimensional arrangement of sidechain atoms play in protein conformation, stability, and overall protein function. Instead, these models implicitly reason about crucial sidechain interactions based on backbone geometry and known amino acid sequence labels. To address this, we present FAMPNN (Full-Atom MPNN), a sequence design method that explicitly models both sequence identity and sidechain conformation for each residue, where the per-token distribution of a residue's discrete amino acid identity and its continuous sidechain conformation are learned with a combined categorical cross-entropy and diffusion loss objective. We demonstrate that learning these distributions jointly is a highly synergistic task that both improves sequence recovery while achieving state-of-the-art sidechain packing. Furthermore, benefits from full-atom modeling generalize from sequence recovery to practical protein design applications, such as zero-shot prediction of experimental binding and stability measurements.</p>","PeriodicalId":74504,"journal":{"name":"Proceedings of machine learning research","volume":"267 ","pages":"66746-66771"},"PeriodicalIF":0.0,"publicationDate":"2025-07-01","publicationTypes":"Journal Article","fieldsOfStudy":null,"isOpenAccess":false,"openAccessPdf":"https://www.ncbi.nlm.nih.gov/pmc/articles/PMC12646570/pdf/","citationCount":null,"resultStr":null,"platform":"Semanticscholar","paperid":"145643666","PeriodicalName":null,"FirstCategoryId":null,"ListUrlMain":null,"RegionNum":0,"RegionCategory":"","ArticlePicture":[],"TitleCN":null,"AbstractTextCN":null,"PMCID":"OA","EPubDate":null,"PubModel":null,"JCR":null,"JCRName":null,"Score":null,"Total":0}
Naitong Chen, Jonathan H Huggins, Trevor Campbell
A Bayesian coreset is a small, weighted subset of a data set that replaces the full data during inference to reduce computational cost. The state-of-the-art coreset construction algorithm, Coreset Markov chain Monte Carlo (Coreset MCMC), uses draws from an adaptive Markov chain targeting the coreset posterior to train the coreset weights via stochastic gradient optimization. However, the quality of the constructed coreset, and thus the quality of its posterior approximation, is sensitive to the stochastic optimization learning rate. In this work, we propose a learning-rate-free stochastic gradient optimization procedure, Hot-start Distance over Gradient (Hot DoG), for training coreset weights in Coreset MCMC without user tuning effort. We provide a theoretical analysis of the convergence of the coreset weights produced by Hot DoG. We also provide empirical results demonstrating that Hot DoG provides higher quality posterior approximations than other learning-rate-free stochastic gradient methods, and performs competitively with optimally tuned ADAM.
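A bare-bones sketch of the distance-over-gradient step size that Hot DoG builds on: at each iteration the step size is the maximum distance travelled from the initial point divided by the square root of the accumulated squared gradient norms, so no learning rate is tuned. The hot-start initialization and the Coreset MCMC integration are omitted, and `r_eps` is an illustrative constant.

```python
# DoG-style learning-rate-free stochastic gradient descent (sketch).
import numpy as np

def dog_minimize(grad_fn, w0, steps=200, r_eps=1e-4):
    w = np.asarray(w0, dtype=float).copy()
    max_dist, grad_sq_sum = r_eps, 0.0
    for _ in range(steps):
        g = grad_fn(w)                            # stochastic gradient of the objective
        grad_sq_sum += float(np.dot(g, g))
        eta = max_dist / np.sqrt(grad_sq_sum)     # no user-chosen learning rate
        w = w - eta * g
        max_dist = max(max_dist, float(np.linalg.norm(w - w0)))
    return w

# Toy usage: minimize ||w - 3||^2 without tuning a learning rate.
print(dog_minimize(lambda w: 2 * (w - 3.0), w0=np.zeros(2)))
```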
{"title":"Tuning-Free Coreset Markov Chain Monte Carlo via Hot DoG.","authors":"Naitong Chen, Jonathan H Huggins, Trevor Campbell","doi":"","DOIUrl":"","url":null,"abstract":"<p><p>A Bayesian coreset is a small, weighted subset of a data set that replaces the full data during inference to reduce computational cost. The state-of-the-art coreset construction algorithm, <i>Coreset Markov chain Monte Carlo</i> (Coreset MCMC), uses draws from an adaptive Markov chain targeting the coreset posterior to train the coreset weights via stochastic gradient optimization. However, the quality of the constructed coreset, and thus the quality of its posterior approximation, is sensitive to the stochastic optimization learning rate. In this work, we propose a learning-rate-free stochastic gradient optimization procedure, <i>Hot-start Distance over Gradient</i> (Hot DoG), for training coreset weights in Coreset MCMC without user tuning effort. We provide a theoretical analysis of the convergence of the coreset weights produced by Hot DoG. We also provide empirical results demonstrate that Hot DoG provides higher quality posterior approximations than other learning-rate-free stochastic gradient methods, and performs competitively to optimally-tuned ADAM.</p>","PeriodicalId":74504,"journal":{"name":"Proceedings of machine learning research","volume":"286 ","pages":"647-672"},"PeriodicalIF":0.0,"publicationDate":"2025-07-01","publicationTypes":"Journal Article","fieldsOfStudy":null,"isOpenAccess":false,"openAccessPdf":"https://www.ncbi.nlm.nih.gov/pmc/articles/PMC12704252/pdf/","citationCount":null,"resultStr":null,"platform":"Semanticscholar","paperid":"145770231","PeriodicalName":null,"FirstCategoryId":null,"ListUrlMain":null,"RegionNum":0,"RegionCategory":"","ArticlePicture":[],"TitleCN":null,"AbstractTextCN":null,"PMCID":"OA","EPubDate":null,"PubModel":null,"JCR":null,"JCRName":null,"Score":null,"Total":0}