docs: 添加题目名称技术路线预期结果文档
- 新增毕业设计题目说明、技术路线规划和预期结果描述 - 优化深度学习模型代码,支持 PyTorch 可选依赖
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
from __future__ import annotations
|
||||
import copy
|
||||
import os
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
@@ -58,6 +59,9 @@ DEFAULT_BATCH_SIZE = 256
|
||||
EARLY_STOPPING_PATIENCE = 12
|
||||
|
||||
|
||||
BaseTorchModule = nn.Module if nn is not None else object
|
||||
|
||||
|
||||
class SequenceStaticDataset(Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -86,7 +90,7 @@ class SequenceStaticDataset(Dataset):
|
||||
)
|
||||
|
||||
|
||||
class LSTMMLPRegressor(nn.Module):
|
||||
class LSTMMLPRegressor(BaseTorchModule):
|
||||
def __init__(
|
||||
self,
|
||||
seq_num_dim: int,
|
||||
|
||||
Reference in New Issue
Block a user