直接给出代码:

class DataPreFetcher(object):
    """Prefetch dict batches onto the GPU on a side CUDA stream.

    Wraps an iterable of dict batches. While the caller works on the current
    batch, the next one is already being copied to the GPU on a separate
    stream. Non-tensor values in the batch dict are passed through untouched.

    The host-to-device copy is issued with ``non_blocking=True``; per the
    PyTorch CUDA notes this only overlaps with compute when the source
    tensors live in pinned memory (e.g. ``DataLoader(pin_memory=True)``).

    When CUDA is not available the prefetcher degrades to a plain
    pass-through iterator so the same training loop can run on CPU.

    Args:
        loader: Iterable yielding dict batches (anything with ``.items()``).
    """

    def __init__(self, loader):
        self.loader = iter(loader)
        # ``None`` marks CPU-only mode: no side stream, no device copies.
        self.stream = torch.cuda.Stream() if torch.cuda.is_available() else None
        self.preload()

    def preload(self):
        """Fetch the next batch and start copying its tensors to the GPU.

        Sets ``self.next_data`` to ``None`` once the loader is exhausted.
        """
        try:
            self.next_data = next(self.loader)
        except StopIteration:
            self.next_data = None
            return
        if self.stream is None:
            return
        with torch.cuda.stream(self.stream):
            # Only existing keys are reassigned, so mutating the dict while
            # iterating over it is safe.
            for k, v in self.next_data.items():
                if isinstance(v, torch.Tensor):
                    self.next_data[k] = v.cuda(non_blocking=True)

    def next(self):
        """Return the prefetched batch and kick off loading of the next one.

        Returns:
            The batch dict, or ``None`` when the loader is exhausted.
        """
        data = self.next_data
        if self.stream is not None:
            current = torch.cuda.current_stream()
            # Make the consumer stream wait for the pending copies.
            current.wait_stream(self.stream)
            if data is not None:
                for v in data.values():
                    if isinstance(v, torch.Tensor) and v.is_cuda:
                        # The tensor was allocated on the side stream. Tell
                        # the caching allocator it is now used on the current
                        # stream, otherwise its memory may be recycled by the
                        # next preload() while kernels here still read it.
                        v.record_stream(current)
        self.preload()
        return data
class data_prefetcher():
    """Prefetch ``(input, target)`` batches onto the GPU on a side CUDA stream.

    Each batch from the loader is unpacked into two tensors. The input is
    cast to ``float`` and the target to ``long``; on a CUDA machine both are
    first copied to the GPU on a separate stream so the transfer can overlap
    with work on the current batch.

    The copy uses ``non_blocking=True``; per the PyTorch CUDA notes this is
    only asynchronous when the source tensors are in pinned memory
    (e.g. ``DataLoader(pin_memory=True)``).

    When CUDA is not available the prefetcher still performs the dtype casts
    but keeps everything on CPU.

    Args:
        loader: Iterable yielding ``(input, target)`` tensor pairs.
    """

    def __init__(self, loader):
        # NOTE(review): the original listing carried the comments
        # "loader 1: real / loader 2: fake", but only one loader is accepted
        # here -- presumably left over from a two-loader variant.
        # ``None`` marks CPU-only mode: no side stream, no device copies.
        self.stream = torch.cuda.Stream() if torch.cuda.is_available() else None
        self.loader = iter(loader)
        self.preload()

    def preload(self):
        """Fetch the next pair and start moving it to the GPU.

        Sets ``self.next_input`` and ``self.next_target`` to ``None`` once
        the loader is exhausted.
        """
        try:
            self.next_input, self.next_target = next(self.loader)
        except StopIteration:
            self.next_input = None
            self.next_target = None
            return
        if self.stream is None:
            self.next_input = self.next_input.float()
            self.next_target = self.next_target.long()
            return
        with torch.cuda.stream(self.stream):
            self.next_input = self.next_input.cuda(non_blocking=True).float()
            self.next_target = self.next_target.cuda(non_blocking=True).long()

    def next(self):
        """Return the prefetched pair and kick off loading of the next one.

        Returns:
            ``(input, target)``; both are ``None`` when the loader is
            exhausted.
        """
        batch_input = self.next_input
        batch_target = self.next_target
        if self.stream is not None:
            current = torch.cuda.current_stream()
            # Make the consumer stream wait for the pending copies/casts.
            current.wait_stream(self.stream)
            # The tensors were allocated on the side stream. Tell the caching
            # allocator they are now used on the current stream, otherwise
            # their memory may be recycled by the next preload() while
            # kernels on this stream still read them.
            if batch_input is not None:
                batch_input.record_stream(current)
            if batch_target is not None:
                batch_target.record_stream(current)
        self.preload()
        return batch_input, batch_target
class DataPreFetcher(object):
    """Overlap host-to-GPU transfer of dict batches with computation.

    The next batch is pulled from ``dataLoader`` ahead of time and every
    ``torch.Tensor`` value in it is copied to the GPU on a dedicated CUDA
    stream. Values that are not tensors are left as they are.

    The copy uses ``non_blocking=True``; per the PyTorch CUDA notes this is
    only asynchronous when the source tensors are in pinned memory
    (e.g. ``DataLoader(pin_memory=True)``).

    Without CUDA the class acts as a plain pass-through over the loader, so
    code written against it also runs on CPU-only machines.

    Args:
        dataLoader: Iterable yielding dict batches (anything with
            ``.items()``).
    """

    def __init__(self, dataLoader):
        # ``None`` marks CPU-only mode: no side stream, no device copies.
        self.stream = torch.cuda.Stream() if torch.cuda.is_available() else None
        self.dataLoader = iter(dataLoader)
        self.preload()

    def preload(self):
        """Fetch the next batch and start copying its tensors to the GPU.

        Sets ``self.next_batch_data`` to ``None`` once the loader is
        exhausted.
        """
        try:
            self.next_batch_data = next(self.dataLoader)
        except StopIteration:
            self.next_batch_data = None
            return
        if self.stream is None:
            return
        with torch.cuda.stream(self.stream):
            # Only existing keys are reassigned, so mutating the dict while
            # iterating over it is safe.
            for k, v in self.next_batch_data.items():
                if isinstance(v, torch.Tensor):
                    self.next_batch_data[k] = v.cuda(non_blocking=True)

    def next(self):
        """Return the prefetched batch and kick off loading of the next one.

        Returns:
            The batch dict, or ``None`` when the loader is exhausted.
        """
        batch_data = self.next_batch_data
        if self.stream is not None:
            current = torch.cuda.current_stream()
            # Make the consumer stream wait for the pending copies.
            current.wait_stream(self.stream)
            if batch_data is not None:
                for v in batch_data.values():
                    if isinstance(v, torch.Tensor) and v.is_cuda:
                        # Allocated on the side stream: register its use on
                        # the current stream so the caching allocator does
                        # not hand the memory to the next preload() while it
                        # is still being read here.
                        v.record_stream(current)
        self.preload()
        return batch_data



参考资料:
pytorch使用data_prefetcher提升数据读取速度
使用DataPrefetcher加速PyTorch的dataloader
【pytorch】给训练踩踩油门-- Pytorch 加速数据读取