FFCV: Making Data Loading No Longer the Bottleneck of Model Training
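A while ago, while browsing GitHub, I came across FFCV, a library that speeds up overall training mainly by optimizing the data-loading process. The authors also published some benchmarks, in which it looks considerably faster than other optimization libraries such as DALI and PyTorch Lightning.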

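On the one hand, I work on deep learning frameworks, and data-loading optimization is a major part of that work. On the other hand, PyTorch's data-loading speed has been criticized for a long time: PyTorch targets researchers, and most of them simply throw OpenCV or PIL at data preprocessing without much thought. I was curious how much faster PyTorch training could get if this part were written carefully, so I wrote this article to analyze it (corrections are welcome if I got anything wrong).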
Code: https://github.com/libffcv/ffcv
Documentation: https://docs.ffcv.io/index.html
Reddit discussion: https://www.reddit.com/r/MachineLearning/comments/s781sr/p_ffcv_accelerated_model_training_via_fast_data/
Quick Start
This section mostly condenses the official documentation.
Creating a Dataset
https://docs.ffcv.io/writing_datasets.html
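An FFCV dataset is stored in a custom format, .beton, so the first step is to convert your dataset into that format.
Here we take an indexable dataset as an example. First create a Dataset object that supports indexing; it needs to implement the __getitem__ and __len__ methods: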
import numpy as np
class LinearRegressionDataset:
    def __init__(self, N, d):
        self.X = np.random.randn(N, d)
        self.Y = np.random.randn(N)
    def __getitem__(self, idx):
        return (self.X[idx].astype('float32'), self.Y[idx])
    def __len__(self):
        return len(self.X)
N, d = (100, 6)
dataset = LinearRegressionDataset(N, d)
This creates a dataset with 100 samples, where each X is a 6-dimensional vector and each Y is a scalar.
Next, use DatasetWriter to write your Dataset out in the .beton format:
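from ffcv.writer import DatasetWriter
from ffcv.fields import NDArrayField, FloatField
write_path = '/path/to/dataset.beton'  # where the .beton file will be written (placeholder)
writer = DatasetWriter(write_path, {
    'covariate': NDArrayField(shape=(d,), dtype=np.dtype('float32')),
    'label': FloatField(),
}, num_workers=16)
writer.from_indexed_dataset(dataset)  # iterate over the dataset and write every sample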
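- write_path: the path the dataset file will be written to
- the dictionary: each key names a field, and each value is the Field object describing that part of a sample. In our dataset each X is an ndarray, so it maps to NDArrayField, while each Y is a single floating-point number, so it maps to FloatField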
Using the DataLoader
Once the dataset is built we can use it. FFCV's Loader is in fact very similar to PyTorch's DataLoader, and is used like this:
from ffcv.loader import Loader, OrderOption
loader = Loader('/path/to/dataset.beton',
                batch_size=BATCH_SIZE,
                num_workers=NUM_WORKERS,
                order=ORDERING,  # e.g. OrderOption.SEQUENTIAL or OrderOption.RANDOM
                pipelines=PIPELINES)
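- order: determines the order in which samples are read
- pipelines: the preprocessing pipelines, one per field; we can compose augmentation operations into a pipeline and pass it in here
A sample pipeline looks like this: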
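from typing import List
import torch as ch
import torchvision
from ffcv.pipeline.operation import Operation
from ffcv.fields.decoders import SimpleRGBImageDecoder
from ffcv.transforms import (RandomHorizontalFlip, RandomTranslate, ToTensor,
                             ToDevice, ToTorchImage, Convert)
# MEAN and STD: per-channel image statistics (placeholders)
image_pipeline: List[Operation] = [
    SimpleRGBImageDecoder(),
    RandomHorizontalFlip(),
    torchvision.transforms.ColorJitter(.4,.4,.4),
    RandomTranslate(padding=2),
    ToTensor(),
    ToDevice('cuda:0', non_blocking=True),
    ToTorchImage(),
    Convert(ch.float16),
    torchvision.transforms.Normalize(MEAN, STD), # Normalize using image statistics
]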
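Stripped of decoding, JIT compilation and device transfers, a pipeline is just an ordered list of operations, each consuming the previous one's output. The pure-Python sketch below shows only that idea; `hypothetical_flip` and `hypothetical_scale` are made-up stand-ins, not FFCV operations.

```python
from typing import Callable, List, Sequence

Image = List[List[float]]                  # one image as rows of pixel values
Op = Callable[[List[Image]], List[Image]]  # an operation maps a batch to a batch

def hypothetical_flip(batch: List[Image]) -> List[Image]:
    # Horizontal flip: reverse every row of every image.
    return [[list(reversed(row)) for row in img] for img in batch]

def hypothetical_scale(batch: List[Image]) -> List[Image]:
    # Map uint8-style values in [0, 255] to floats in [0, 1].
    return [[[px / 255 for px in row] for row in img] for img in batch]

def run_pipeline(ops: Sequence[Op], batch: List[Image]) -> List[Image]:
    # Each operation receives the previous operation's output, strictly in list order.
    for op in ops:
        batch = op(batch)
    return batch

batch = [[[0, 255], [51, 102]]]  # a batch holding one 2x2 image
print(run_pipeline([hypothetical_flip, hypothetical_scale], batch))
# [[[1.0, 0.0], [0.4, 0.2]]]
```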
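That's it for the quick introduction; now let's look at some of the techniques behind it.
The codebase is organized into the following main parts:
- libffcv: a C extension written by the authors
- ffcv: the main Python library
  |- fields: data structures
  |- loader: the data loader
  |- memory_manager: memory management
  |- pipeline: the data-processing pipeline
  |- transforms: augmentation operations
  |- traversal_order: controls the order in which data are traversed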
libffcv
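Using the Python C extension API, the authors wrote a handful of essential functions, such as memcpy, fileread, imdecode and resize.
resize is implemented with OpenCV, while JPEG decoding uses the libjpeg-turbo (TurboJPEG) library.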
fields
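Fields are FFCV's data structures. Each sample in a dataset is made up of one or more fields, and every field implements its own encoding and decoding logic, which correspond to writing and reading the dataset respectively.
Take FloatField as an example: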
# Excerpt: Field, ARG_TYPE, FloatDecoder and Operation are defined elsewhere in ffcv
class FloatField(Field):
    """
    A subclass of :class:`~ffcv.fields.Field` supporting (scalar) floating-point (float64)
    values.
    """
    def __init__(self):
        pass
    @property
    def metadata_type(self) -> np.dtype:
        return np.dtype('<f8')  # little-endian float64
    @staticmethod
    def from_binary(binary: ARG_TYPE) -> Field:
        return FloatField()
    def to_binary(self) -> ARG_TYPE:
        return np.zeros(1, dtype=ARG_TYPE)[0]
    def encode(self, destination, field, malloc):
        destination[0] = field
    def get_decoder_class(self) -> Type[Operation]:
        return FloatDecoder
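Because `metadata_type` is a fixed dtype (`'<f8'`) and `encode` simply writes the value into `destination[0]`, a scalar float appears to be stored inline in a fixed-size per-sample record. The sketch below illustrates why fixed-size records are convenient, using only the standard `struct` module: sample `i` can be located by arithmetic alone. The record layout (`RECORD`, a float64 followed by an int64) and the helper names are hypothetical; this is not the actual .beton layout.

```python
import struct
from typing import List, Tuple

# Hypothetical fixed-size record: a little-endian float64 (like FloatField's '<f8')
# followed by a little-endian int64. This is NOT the real .beton layout.
RECORD = struct.Struct('<dq')

def encode_records(samples: List[Tuple[float, int]]) -> bytes:
    # Every record has the same size, so sample i starts at byte i * RECORD.size.
    return b''.join(RECORD.pack(label, value) for label, value in samples)

def decode_record(buf: bytes, idx: int) -> Tuple[float, int]:
    # Random access: compute the offset instead of parsing the preceding records.
    return RECORD.unpack_from(buf, idx * RECORD.size)

buf = encode_records([(0.5, 1), (-2.25, 7), (3.0, 42)])
print(RECORD.size, len(buf))  # 16 48
print(decode_record(buf, 1))  # (-2.25, 7)
```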
loader
The ffcv Loader is the counterpart of PyTorch's DataLoader:
class Loader:
    def __init__(self,
                 fname: str,
                 batch_size: int,
                 num_workers: int = -1,
                 os_cache: bool = DEFAULT_OS_CACHE,
                 order: ORDER_TYPE = OrderOption.SEQUENTIAL,
                 distributed: bool = False,
                 seed: int = None,  # For ordering of samples
                 indices: Sequence[int] = None,  # For subset selection
                 pipelines: Mapping[str,
                                    Sequence[Union[Operation, ch.nn.Module]]] = {},
                 custom_fields: Mapping[str, Type[Field]] = {},
                 drop_last: bool = True,
                 batches_ahead: int = 3,
                 recompile: bool = False,  # Recompile at every epoch
                 ):
        ...  # body omitted
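Let's go over a few of the important parameters:
- os_cache: the caching strategy; see memory_manager below
- order: the order in which samples are read
- pipelines: the data-preprocessing pipelines. FFCV gathers all the preprocessing for a field into one pipeline and then uses a JIT compiler to speed up those operations
- recompile: as mentioned above, the preprocessing operations are JIT-compiled, so if the operations differ from one epoch to the next, they need to be recompiled every epoch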
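To make `order`, `seed`, `batch_size`, `drop_last` and `indices` concrete, here is a pure-Python sketch of how one epoch's batches could be produced. It is not FFCV's traversal_order code: in particular, deriving the per-epoch shuffle from `seed` and the epoch number is an assumption made for illustration, and `epoch_batches` is a hypothetical helper.

```python
import random
from typing import List, Optional, Sequence

def epoch_batches(n: int, batch_size: int, shuffle: bool, seed: int, epoch: int,
                  drop_last: bool = True,
                  indices: Optional[Sequence[int]] = None) -> List[List[int]]:
    # Subset selection, analogous to the `indices` argument: only these samples are visited.
    order = list(indices) if indices is not None else list(range(n))
    if shuffle:
        # Assumption: derive a per-epoch RNG from (seed, epoch), so runs are reproducible
        # while each epoch still gets a different permutation.
        random.Random(seed * 1_000_003 + epoch).shuffle(order)
    batches = [order[i:i + batch_size] for i in range(0, len(order), batch_size)]
    if drop_last and batches and len(batches[-1]) < batch_size:
        batches.pop()  # drop the incomplete final batch, as drop_last=True does
    return batches

print(epoch_batches(10, 4, shuffle=False, seed=0, epoch=0))
# [[0, 1, 2, 3], [4, 5, 6, 7]]
print(epoch_batches(10, 4, shuffle=False, seed=0, epoch=0, drop_last=False))
# [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
```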
memory_manager
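This is the memory-management component. It decides how the contents of the dataset file are brought into memory, and it comes with two strategies, selected by the Loader's os_cache parameter.
The first, meant for when memory is plentiful, relies on the OS-level cache: np.memmap maps the file on disk into virtual memory, and data are only copied in from disk when a page fault occurs.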
class OSCacheContext(MemoryContext):
    def __init__(self, manager:MemoryManager):
        self.manager = manager
        self.mmap = None
    @property
    def state(self):
        return (self.mmap, self.manager.ptrs, self.manager.sizes)
    def __enter__(self):
        res = super().__enter__()
        if self.mmap is None:
            self.mmap = np.memmap(self.manager.reader.file_name,
                                  'uint8', mode='r')
        return res
    # ...
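`np.memmap` is a thin wrapper around the operating system's memory mapping, so the same OS-cache behaviour can be shown with the standard-library `mmap` module. A minimal sketch, assuming a hypothetical file of fixed-size float64 records (not the .beton format): the file is mapped read-only and a single sample is read by slicing the mapping. `write_file`, `read_sample` and `demo` are hypothetical helpers.

```python
import mmap
import os
import struct
import tempfile
from typing import Iterable, Tuple

RECORD = struct.Struct('<d')  # hypothetical layout: one little-endian float64 per sample

def write_file(path: str, values: Iterable[float]) -> None:
    with open(path, 'wb') as f:
        for v in values:
            f.write(RECORD.pack(v))

def read_sample(mm: mmap.mmap, idx: int) -> float:
    # Only the page(s) holding this record need to be resident: the first access
    # triggers a page fault, the OS reads the page from disk and keeps it in its
    # page cache, where other processes mapping the same file can reuse it.
    start = idx * RECORD.size
    return RECORD.unpack(mm[start:start + RECORD.size])[0]

def demo() -> Tuple[int, float]:
    with tempfile.TemporaryDirectory() as d:
        path = os.path.join(d, 'data.bin')
        write_file(path, [i * 1.5 for i in range(1000)])
        with open(path, 'rb') as f:
            with mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) as mm:
                return len(mm), read_sample(mm, 3)

print(demo())  # (8000, 4.5)
```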
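The other strategy is a process-level cache, meant for when the dataset does not fit in memory: it keeps a fixed number of pages, frees the pages that are no longer needed after each batch, and prefetches the data needed by upcoming batches.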
    # ... (pages_in_batch, leaving_at, can_prefetch_at and entering_at are defined earlier in this function)
    # We now find how many pages we need to keep in our buffer
    # We also determine where which page is going to reside
    next_slot = 0
    page_to_slot = {}
    free_slots = set()
    # For each batch
    for b_id in range(len(pages_in_batch)):
        # First we free the pages that are leaving
        for page in leaving_at[b_id]:
            free_slots.add(page_to_slot[page])
        # We use the prefetch timing here because we want to be able
        # To start prefetching ahead of time and not overwrite a slot
        # That is currently used
        for page in can_prefetch_at[b_id]:
            # Then we find a slot for the incoming pages
            if free_slots:
                # There is a slot available for this page
                slot = free_slots.pop()
            else:
                # We have to allocate a new slot because we ran out
                slot = next_slot
                next_slot += 1
            page_to_slot[page] = slot
    return Schedule(next_slot, page_to_slot,
                    can_prefetch_at, entering_at, leaving_at)
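The slot assignment above can be tried out in isolation. The sketch below is a self-contained re-implementation under stated assumptions: the excerpt does not show how `leaving_at` and `can_prefetch_at` are computed, so here a page is assumed to become prefetchable `lookahead` batches before its first use and to free its slot right after its last use. `schedule_slots` and `lookahead` are hypothetical names; the allocation loop itself mirrors the excerpt.

```python
from collections import defaultdict
from typing import Dict, List, Set, Tuple

def schedule_slots(pages_in_batch: List[Set[int]], lookahead: int) -> Tuple[int, Dict[int, int]]:
    # Record the first and last batch that touches each page.
    first_use: Dict[int, int] = {}
    last_use: Dict[int, int] = {}
    for b_id, pages in enumerate(pages_in_batch):
        for page in pages:
            first_use.setdefault(page, b_id)
            last_use[page] = b_id
    # Assumed timing: load `lookahead` batches early, free right after the last use.
    can_prefetch_at: Dict[int, List[int]] = defaultdict(list)
    leaving_at: Dict[int, List[int]] = defaultdict(list)
    for page in first_use:
        can_prefetch_at[max(0, first_use[page] - lookahead)].append(page)
        leaving_at[last_use[page] + 1].append(page)
    # Same slot-allocation loop as the excerpt: reuse freed slots, else grow the buffer.
    next_slot = 0
    page_to_slot: Dict[int, int] = {}
    free_slots: Set[int] = set()
    for b_id in range(len(pages_in_batch)):
        for page in leaving_at[b_id]:
            free_slots.add(page_to_slot[page])
        for page in sorted(can_prefetch_at[b_id]):
            if free_slots:
                slot = free_slots.pop()
            else:
                slot = next_slot
                next_slot += 1
            page_to_slot[page] = slot
    return next_slot, page_to_slot

pages = [{0, 1}, {1, 2}, {2, 3}, {3, 4}]
print(schedule_slots(pages, lookahead=1))  # (3, {0: 0, 1: 1, 2: 2, 3: 0, 4: 1})
print(schedule_slots(pages, lookahead=0))  # (2, {0: 0, 1: 1, 2: 0, 3: 1, 4: 0})
```

With a lookahead of one batch, the five pages in the example share three slots; with no lookahead, two slots suffice. Prefetching further ahead hides more I/O latency but needs a larger buffer.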
Pipeline
It is further divided into several smaller components.
Operation
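This is the base class for data-preprocessing operations. Its generate_code method returns the code (a plain Python function) that performs the operation, so that the code can later be JIT-compiled for speed.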
class Operation(ABC):
    def __init__(self):
        self.metadata: np.ndarray = None
        self.memory_read: Callable[[np.uint64], np.ndarray] = None
        pass
        # ...
    @abstractmethod
    def declare_state_and_memory(self, previous_state: State) -> Tuple[State, Optional[AllocationQuery]]:
        raise NotImplementedError
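The key idea is that an Operation does not process data directly: generate_code returns a plain function, which the Compiler can later compile. A minimal pure-Python sketch of that pattern follows; `ToyOperation` and `AddConstant` are hypothetical classes (the state and memory declarations are left out), and nothing here is JIT-compiled.

```python
from abc import ABC, abstractmethod
from typing import Callable, List

class ToyOperation(ABC):
    # Hypothetical, stripped-down analogue of ffcv's Operation: only code generation.
    @abstractmethod
    def generate_code(self) -> Callable:
        raise NotImplementedError

class AddConstant(ToyOperation):
    def __init__(self, value: float):
        self.value = value

    def generate_code(self) -> Callable:
        # Copy attributes into locals: the returned function must not reference
        # `self`, since numba's nopython mode cannot compile arbitrary Python objects.
        value = self.value

        def add(batch: List[float], dst: List[float]) -> List[float]:
            # Write results into a preallocated output buffer, then return it.
            for i in range(len(batch)):
                dst[i] = batch[i] + value
            return dst

        return add

op_code = AddConstant(2.0).generate_code()
print(op_code([1.0, 2.5], [0.0, 0.0]))  # [3.0, 4.5]
```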
Compiler
As the name suggests, this is a "compiler" for the data-loading operations: the idea is to compile the preprocessing operations with numba.njit to speed them up.
from multiprocessing import cpu_count
import torch as ch
from numba import njit, set_num_threads
class Compiler:
    @classmethod
    def set_enabled(cls, b):
        cls.is_enabled = b
    @classmethod
    def set_num_threads(cls, n):
        if n < 1:
            n = cpu_count()
        cls.num_threads = n
        set_num_threads(n)
        ch.set_num_threads(n)
    @classmethod
    def compile(cls, code, signature=None):
        parallel = False
        if hasattr(code, 'is_parallel'):
            parallel = code.is_parallel and cls.num_threads > 1
        if cls.is_enabled:
            return njit(signature, fastmath=True, nogil=True, error_model='numpy',
                        parallel=parallel)(code)
        return code
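Note that fastmath is enabled by default here, so in some floating-point cases the results may differ from an ordinary, strictly ordered computation (a painful lesson from years of aligning losses).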
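fastmath=True allows LLVM to, among other things, reassociate floating-point operations (for example, to vectorize a reduction into several partial sums). Since floating-point addition is not associative, a reordered sum can differ from the sequential one. The plain-Python sketch below does not use numba; it only shows that changing the summation order alone changes the result. `sum_in_order` is a hypothetical helper.

```python
def sum_in_order(xs):
    # Accumulate strictly left to right, like a sequential reduction loop.
    # (The built-in sum() is avoided: since Python 3.12 it uses compensated summation.)
    total = 0.0
    for x in xs:
        total += x
    return total

# Same three numbers, two summation orders.
print(sum_in_order([1e16, 1.0, -1e16]))  # 0.0  (1e16 + 1.0 rounds back to 1e16)
print(sum_in_order([1e16, -1e16, 1.0]))  # 1.0
print((0.1 + 0.2) + 0.3 == 0.1 + (0.2 + 0.3))  # False
```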
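Next, let's look at the main pipeline code. This is the data-preprocessing pipeline, and its main steps are:
Parsing the pipeline
What gets passed in is a sequence of Operations, so the first step is to call each Operation's declare_state_and_memory to obtain the state it produces and the memory it needs. Along the way, consecutive Operations with the same jit_mode are grouped into one block: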
def parse_pipeline(self, batch_size=16):
    memory_allocations: Mapping[int, Optional[Allocation]] = {}
    operation_blocs = []
    current_state: State = self.original_state
    current_block = []
    # We read the content of the pipeline, validate and collect
    # Memory allocations
    for op_id, operation in enumerate(self.operations):
        previous_state = current_state
        current_state, memory_allocation = operation.declare_state_and_memory(
            current_state)
        if current_state.jit_mode != previous_state.jit_mode:
            if current_block:
                operation_blocs.append((previous_state.jit_mode, current_block))
            current_block = [op_id]
        else:
            current_block.append(op_id)
        memory_allocations[op_id] = memory_allocation
    if current_block:
        operation_blocs.append((current_state.jit_mode, current_block))
    return operation_blocs, memory_allocations
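The block grouping in parse_pipeline can be reproduced with itertools.groupby: consecutive operations sharing a jit_mode end up in the same block, presumably so that a run of JIT-able operations can be compiled together while the others run as ordinary Python/PyTorch code. This sketch simplifies by giving each operation a fixed jit_mode flag (in FFCV it comes from the State returned by declare_state_and_memory); `group_by_jit_mode` and the example flags are hypothetical.

```python
from itertools import groupby
from typing import List, Tuple

def group_by_jit_mode(jit_modes: List[bool]) -> List[Tuple[bool, List[int]]]:
    # jit_modes[op_id] stands in for the jit_mode of the State returned by that
    # operation's declare_state_and_memory (a simplifying assumption).
    blocks = []
    for mode, group in groupby(enumerate(jit_modes), key=lambda item: item[1]):
        blocks.append((mode, [op_id for op_id, _ in group]))
    return blocks

# Made-up flags: two JIT-able ops, three non-JIT ops, then one more JIT-able op.
print(group_by_jit_mode([True, True, False, False, False, True]))
# [(True, [0, 1]), (False, [2, 3, 4]), (True, [5])]
```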
Compiling the Operation code
This part is simple: it just calls each Operation's generate_code method in turn:
def compile_ops(self):
    compiled_ops = {}
    for op_id, operation in enumerate(self.operations):
        compiled_ops[op_id] = operation.generate_code()
    return compiled_ops
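This design seems to borrow from NVIDIA DALI's pipeline design. FFCV leans on numba's JIT, which spares it most of the kernel-development work: high performance comes from the JIT alone, and users can easily extend the data preprocessing with custom operations written in Python.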
Transform
This is where the data augmentations live: each one subclasses Operation and overrides generate_code.
Take the commonly used ImageMixup as an example:
class ImageMixup(Operation):
    def __init__(self, alpha: float, same_lambda: bool):
        super().__init__()
        self.alpha = alpha
        self.same_lambda = same_lambda
    def generate_code(self) -> Callable:
        alpha = self.alpha
        same_lam = self.same_lambda
        my_range = Compiler.get_iterator()
        def mixer(images, dst, indices):
            np.random.seed(indices[-1])
            num_images = images.shape[0]
            lam = np.random.beta(alpha, alpha) if same_lam else \
                  np.random.beta(alpha, alpha, num_images)
            for ix in my_range(num_images):
                l = lam if same_lam else lam[ix]
                dst[ix] = l * images[ix] + (1 - l) * images[ix - 1]
            return dst
        mixer.is_parallel = True
        mixer.with_indices = True
        return mixer
    def declare_state_and_memory(self, previous_state: State) -> Tuple[State, Optional[AllocationQuery]]:
        return (previous_state, AllocationQuery(shape=previous_state.shape,
                                                dtype=previous_state.dtype))
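ImageMixup pairs each image with the previous image in the batch (image 0 wraps around to the last one) and blends them with λ drawn from Beta(alpha, alpha), either one λ for the whole batch or one per image. It reseeds NumPy's RNG from the last sample index, presumably so that the matching label-side mixup can regenerate the same λ values. A pure-Python sketch of the same arithmetic, with `random.Random(seed).betavariate` standing in for `np.random.beta` and images flattened to lists; `mixup_batch` is a hypothetical helper:

```python
import random
from typing import List, Tuple

def mixup_batch(images: List[List[float]], alpha: float, same_lambda: bool,
                seed: int) -> Tuple[List[List[float]], List[float]]:
    # `seed` plays the role of indices[-1] in the excerpt: same seed, same lambdas.
    rng = random.Random(seed)
    n = len(images)
    if same_lambda:
        lams = [rng.betavariate(alpha, alpha)] * n
    else:
        lams = [rng.betavariate(alpha, alpha) for _ in range(n)]
    mixed = []
    for ix in range(n):
        lam = lams[ix]
        partner = images[ix - 1]  # for ix == 0 this is images[-1]: pairing wraps around
        mixed.append([lam * a + (1 - lam) * b for a, b in zip(images[ix], partner)])
    return mixed, lams

imgs = [[0.0, 0.0], [1.0, 1.0], [2.0, 2.0]]
mixed, lams = mixup_batch(imgs, alpha=0.2, same_lambda=False, seed=123)
print(lams)
print(mixed)
```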

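In the Reddit discussion, the authors also mention that they implemented a faster version of the NormalizeImage operation; the code is at: https://github.com/libffcv/ffcv/blob/main/ffcv/transforms/normalize.py
There are separate GPU and CPU implementations; let's focus on the GPU one: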
class NormalizeImage(Operation):
    # ...
    def __init__(self, mean: np.ndarray, std: np.ndarray,
                 type: np.dtype):
        super().__init__()
        table = (np.arange(256)[:, None] - mean[None, :]) / std[None, :]
        # ...
    def generate_code_gpu(self) -> Callable:
        # We only import cupy if it's truly needed
        import cupy as cp
        import pytorch_pfn_extras as ppe
        tn = np.zeros((), dtype=self.dtype).dtype.name
        kernel = cp.ElementwiseKernel(f'uint8 input, raw {tn} table', f'{tn} output', 'output = table[input * 3 + i % 3];')
        final_type = ch_dtype_from_numpy(self.original_dtype)
        s = self
        def normalize_convert(images, result):
            B, C, H, W = images.shape
            table = self.lookup_table.view(-1)
            assert images.is_contiguous(memory_format=ch.channels_last), 'Images need to be in channel last'
            result = result[:B]
            result_c = result.view(-1)
            images = images.permute(0, 2, 3, 1).view(-1)
            current_stream = ch.cuda.current_stream()
            with ppe.cuda.stream(current_stream):
                kernel(images, table, result_c)
            # Mark the result as channel last
            final_result = result.reshape(B, H, W, C).permute(0, 3, 1, 2)
            assert final_result.is_contiguous(memory_format=ch.channels_last), 'Images need to be in channel last'
            return final_result.view(final_type)
        return normalize_convert
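The idea here is actually quite clever. First, table is a lookup table: from the mean and std you pass in, it precomputes the normalized value of each of the 256 possible pixel values 0-255.
For example, with mean = [127.5, 127.5, 127.5] and std = [1, 1, 1], the table has shape (256, 3), where 256 corresponds to the uint8 pixel values 0-255 and 3 to the three RGB channels. Its contents are:
[[-127.5 -127.5 -127.5]   # pixel value 0: normalized values for the R, G, B channels
 [-126.5 -126.5 -126.5]   # pixel value 1
 ...]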
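This lookup table is laid out channels-last (the channel is the innermost dimension), and we flatten it with view:
table = self.lookup_table.view(-1)
Because the table is channels-last, the NCHW input images are permuted into NHWC and flattened as well. The assert in the code above requires the images to already be stored in channels_last memory format, so after the permute the tensor is contiguous and view(-1) can flatten it without a copy. In the flattened order, every three consecutive elements are the R, G and B values of one pixel, which matches the table layout and keeps memory access contiguous:
images = images.permute(0, 2, 3, 1).view(-1)
Then cupy's ElementwiseKernel can be used to apply the operation element by element:
kernel = cp.ElementwiseKernel(f'uint8 input, raw {tn} table', f'{tn} output', 'output = table[input * 3 + i % 3];')
Here input is the input pixel value and i is the element's index (a special variable available inside an ElementwiseKernel); i % 3 tells which of the three RGB channels the element belongs to, so each output is table[input * 3 + channel]. table is declared raw so that it can be indexed freely rather than being iterated element by element alongside input.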
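The whole trick can be checked on the CPU in pure Python: build the flattened (256, 3) table, then index it exactly as the kernel body does, with `table[input * 3 + i % 3]` over a flattened NHWC buffer. `build_table` and `normalize_nhwc_flat` are hypothetical helpers, not FFCV functions; the point is that normalization becomes a single table lookup per element because a uint8 input has only 256 possible values.

```python
from typing import List, Sequence

def build_table(mean: Sequence[float], std: Sequence[float]) -> List[float]:
    # Flattened (256, 3) channels-last table: entry v * 3 + c holds (v - mean[c]) / std[c].
    return [(v - mean[c]) / std[c] for v in range(256) for c in range(3)]

def normalize_nhwc_flat(pixels: Sequence[int], table: Sequence[float]) -> List[float]:
    # pixels: a flattened NHWC uint8 buffer, so element i belongs to channel i % 3.
    # Same indexing as the ElementwiseKernel body: table[input * 3 + i % 3].
    return [table[x * 3 + i % 3] for i, x in enumerate(pixels)]

mean, std = [127.5, 127.5, 127.5], [1.0, 1.0, 1.0]
table = build_table(mean, std)
print(table[:6])  # [-127.5, -127.5, -127.5, -126.5, -126.5, -126.5]
pixels = [0, 128, 255, 10, 20, 30]  # two RGB pixels in NHWC order
print(normalize_nhwc_flat(pixels, table))
# [-127.5, 0.5, 127.5, -117.5, -107.5, -97.5]
```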
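Summary
FFCV is a pretty good library. It does not demand much HPC knowledge or require you to write kernels; instead, it combines fairly mature tools to speed up data loading, keeping the flexibility of PyTorch's DataLoader while delivering much higher performance.
The library already has 1.5k stars. PyTorch's ecosystem really is excellent, with extension libraries built on top of it appearing one after another. But this also shows, indirectly, that some of its problems have to be fixed by the community. The library brings many new ideas, and interested readers may want to give it a try.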

