SpikingJelly脈沖神經網絡深度學習框架
SpikingJelly 是一個基于 PyTorch,使用脈沖神經網絡 (Spiking Network, SNN) 進行深度學習的框架。
SpikingJelly 非常易于使用。使用 SpikingJelly 搭建 SNN,就像使用 PyTorch 搭建 ANN 一樣簡單:
class Net(nn.Module):
def __init__(self, tau=100.0, v_threshold=1.0, v_reset=0.0):
super().__init__()
# 網絡結構,簡單的雙層全連接網絡,每一層之后都是LIF神經元
self.fc = nn.Sequential(
nn.Flatten(),
nn.Linear(28 * 28, 14 * 14, bias=False),
neuron.LIFNode(tau=tau, v_threshold=v_threshold, v_reset=v_reset),
nn.Linear(14 * 14, 10, bias=False),
neuron.LIFNode(tau=tau, v_threshold=v_threshold, v_reset=v_reset)
)
def forward(self, x):
return self.fc(x)
設備支持
- Nvidia GPU
- CPU
像使用 PyTorch 一樣簡單。
>>> net = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10, bias=False), neuron.LIFNode(tau=tau))
>>> net = net.to(device) # Can be CPU or CUDA devices
神經形態(tài)數據集支持
SpikingJelly 已經將下列數據集納入:
| 數據集 | 來源 |
|---|---|
| ASL-DVS | Graph-based Object Classification for Neuromorphic Vision Sensing |
| CIFAR10-DVS | CIFAR10-DVS: An Event-Stream Dataset for Object Classification |
| DVS128 Gesture | A Low Power, Fully Event-Based Gesture Recognition System |
| N-Caltech101 | Converting Static Image Datasets to Spiking Neuromorphic Datasets Using Saccades |
| N-MNIST | Converting Static Image Datasets to Spiking Neuromorphic Datasets Using Saccades |
用戶可以輕松使用事件數據,或由 SpikingJelly 積分生成的幀數據:
from spikingjelly.datasets.dvs128_gesture import DVS128Gesture
root_dir = 'D:/datasets/DVS128Gesture'
event_set = DVS128Gesture(root_dir, train=True, data_type='event')
frame_set = DVS128Gesture(root_dir, train=True, data_type='frame', frames_number=20, split_by='number')
未來將會納入更多數據集。
評論
圖片
表情
