1. GitHub code
# -*- coding: utf-8 -*-
"""
Created on Mon Sep 28 10:28:06 2020
@author: wb
"""
import torch
import torch.nn as nn
from GCN_models import GCN
from One_hot_encoder import One_hot_encoder
class SSelfAttention(nn.Module):
def __init__(self, embed_size, heads):
super(SSelfAttention, self).__init__()
self.embed_size = embed_size # 64
self.heads = heads # 8
self.head_dim = embed_size // heads # 8
assert (
self.head_dim * heads == embed_size
), "Embedding size needs to be divisible by heads"
self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
self.fc_out = nn.Linear(heads * self.head_dim, embed_size)
def forward(self, values, keys, query):
        N, T, C = query.shape  # N: number of nodes, T: time steps, C: embed_size
# Split the embedding into self.heads different pieces
        values = values.reshape(N, T, self.heads, self.head_dim)  # split embed_size into heads x head_dim
keys = keys.reshape(N, T, self.heads, self.head_dim)
query = query.reshape(N, T, self.heads, self.head_dim)
values = self.values(values) # (N, T, heads, head_dim)
keys = self.keys(keys) # (N, T, heads, head_dim)
        queries = self.queries(query)  # (N, T, heads, head_dim)
# Einsum does matrix mult. for query*keys for each training example
# with every other training example, don't be confused by einsum
# it's just how I like doing matrix multiplication & bmm
energy = torch.einsum("qthd,kthd->qkth", [queries, keys]) # 空间self-attention
# queries shape: (N, T, heads, heads_dim),
# keys shape: (N, T, heads, heads_dim)
# energy: (N, N, T, heads)
# Normalize energy values similarly to seq2seq + attention
# so that they sum to 1. Also divide by scaling factor for
# better stability
        attention = torch.softmax(energy / (self.embed_size ** (1 / 2)), dim=1)  # softmax over the key (node) dimension, so the weights sum to 1
# attention shape: (N, N, T, heads)
out = torch.einsum("qkth,kthd->qthd", [attention, values]).reshape(
N, T, self.heads * self.head_dim
)
# attention shape: (N, N, T, heads)
        # values shape: (N, T, heads, head_dim)
# out after matrix multiply: (N, T, heads, head_dim), then
# we reshape and flatten the last two dimensions.
out = self.fc_out(out)
# Linear layer doesn't modify the shape, final shape will be
# (N, T, embed_size)
return out
class TSelfAttention(nn.Module):
def __init__(self, embed_size, heads):
super(TSelfAttention, self).__init__()
self.embed_size = embed_size
self.heads = heads
self.head_dim = embed_size // heads
assert (
self.head_dim * heads == embed_size
), "Embedding size needs to be divisible by heads"
self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
self.fc_out = nn.Linear(heads * self.head_dim, embed_size)
def forward(self, values, keys, query):
N, T, C = query.shape
# Split the embedding into self.heads different pieces
        values = values.reshape(N, T, self.heads, self.head_dim)  # split embed_size into heads x head_dim
keys = keys.reshape(N, T, self.heads, self.head_dim)
query = query.reshape(N, T, self.heads, self.head_dim)
values = self.values(values) # (N, T, heads, head_dim)
keys = self.keys(keys) # (N, T, heads, head_dim)
        queries = self.queries(query)  # (N, T, heads, head_dim)
# Einsum does matrix mult. for query*keys for each training example
# with every other training example, don't be confused by einsum
# it's just how I like doing matrix multiplication & bmm
energy = torch.einsum("nqhd,nkhd->nqkh", [queries, keys]) # 时间self-attention
# queries shape: (N, T, heads, heads_dim),
# keys shape: (N, T, heads, heads_dim)
# energy: (N, T, T, heads)
# Normalize energy values similarly to seq2seq + attention
# so that they sum to 1. Also divide by scaling factor for
# better stability
        attention = torch.softmax(energy / (self.embed_size ** (1 / 2)), dim=2)  # softmax over the key (time) dimension, so the weights sum to 1
# attention shape: (N, query_len, key_len, heads)
out = torch.einsum("nqkh,nkhd->nqhd", [attention, values]).reshape(
N, T, self.heads * self.head_dim
)
# attention shape: (N, T, T, heads)
        # values shape: (N, T, heads, head_dim)
# out after matrix multiply: (N, T, heads, head_dim), then
# we reshape and flatten the last two dimensions.
out = self.fc_out(out)
# Linear layer doesn't modify the shape, final shape will be
# (N, T, embed_size)
return out
class STransformer(nn.Module):
def __init__(self, embed_size, heads, adj, dropout, forward_expansion):
super(STransformer, self).__init__()
# Spatial Embedding
self.adj = adj
self.D_S = nn.Parameter(adj)
self.embed_liner = nn.Linear(adj.shape[0], embed_size)
self.attention = SSelfAttention(embed_size, heads)
self.norm1 = nn.LayerNorm(embed_size)
self.norm2 = nn.LayerNorm(embed_size)
self.feed_forward = nn.Sequential(
nn.Linear(embed_size, forward_expansion * embed_size),
nn.ReLU(),
nn.Linear(forward_expansion * embed_size, embed_size),
)
        # GCN branch
        # input: embed_size; hidden: embed_size*2; output: embed_size
        self.gcn = GCN(embed_size, embed_size*2, embed_size, dropout)
        self.norm_adj = nn.InstanceNorm2d(1)  # normalize the adjacency matrix
self.dropout = nn.Dropout(dropout)
self.fs = nn.Linear(embed_size, embed_size)
self.fg = nn.Linear(embed_size, embed_size)
def forward(self, value, key, query):
        # spatial embedding: project each row of the adjacency matrix to embed_size
        N, T, C = query.shape
        D_S = self.embed_liner(self.D_S)  # (N, embed_size)
        D_S = D_S.expand(T, N, C)         # (T, N, C)
        D_S = D_S.permute(1, 0, 2)        # (N, T, C)
        # GCN branch
        X_G = torch.Tensor(query.shape[0], 0, query.shape[2])  # empty (N, 0, C) tensor to concatenate onto
        self.adj = self.adj.unsqueeze(0).unsqueeze(0)
        self.adj = self.norm_adj(self.adj)
        self.adj = self.adj.squeeze(0).squeeze(0)
        # apply the GCN at each time step to extract that step's spatial features
for t in range(query.shape[1]):
o = self.gcn(query[ : , t, : ], self.adj)
o = o.unsqueeze(1) # shape [N, 1, C]
X_G = torch.cat((X_G, o), dim=1)
        # Spatial Transformer: the spatial embedding is added to the query (the original paper concatenates instead)
query = query+D_S
attention = self.attention(value, key, query)
# Add skip connection, run through normalization and finally dropout
x = self.dropout(self.norm1(attention + query))
forward = self.feed_forward(x)
        # residual connection, then dropout to reduce overfitting
U_S = self.dropout(self.norm2(forward + x))
        # gated fusion of the STransformer output and the GCN output
g = torch.sigmoid( self.fs(U_S) + self.fg(X_G) ) # (7)
out = g*U_S + (1-g)*X_G # (8)
return out
class TTransformer(nn.Module):
def __init__(self, embed_size, heads, time_num, dropout, forward_expansion):
super(TTransformer, self).__init__()
# Temporal embedding One hot
self.time_num = time_num
        self.one_hot = One_hot_encoder(embed_size, time_num)          # temporal embedding via one-hot encoding, or
        self.temporal_embedding = nn.Embedding(time_num, embed_size)  # temporal embedding via nn.Embedding
self.attention = TSelfAttention(embed_size, heads)
self.norm1 = nn.LayerNorm(embed_size)
self.norm2 = nn.LayerNorm(embed_size)
self.feed_forward = nn.Sequential(
nn.Linear(embed_size, forward_expansion * embed_size),
nn.ReLU(),
nn.Linear(forward_expansion * embed_size, embed_size),
)
self.dropout = nn.Dropout(dropout)
def forward(self, value, key, query, t):
        N, T, C = query.shape  # e.g. 25, 12, 64
        D_T = self.one_hot(t, N, T)                        # temporal embedding via one-hot encoding, or
        D_T = self.temporal_embedding(torch.arange(0, T))  # (12, 64): temporal embedding via nn.Embedding (overwrites the one-hot version)
        D_T = D_T.expand(N, T, C)                          # (25, 12, 64)
        # the temporal embedding is added to the query (the original paper concatenates instead)
query = query + D_T
attention = self.attention(value, key, query)
# Add skip connection, run through normalization and finally dropout
x = self.dropout(self.norm1(attention + query))
forward = self.feed_forward(x)
out = self.dropout(self.norm2(forward + x))
return out
class STTransformerBlock(nn.Module):
def __init__(self, embed_size, heads, adj, time_num, dropout, forward_expansion):
super(STTransformerBlock, self).__init__()
self.STransformer = STransformer(embed_size, heads, adj, dropout, forward_expansion)
        # time_num is passed through so TTransformer can size its temporal embedding
self.TTransformer = TTransformer(embed_size, heads, time_num, dropout, forward_expansion)
self.norm1 = nn.LayerNorm(embed_size)
self.norm2 = nn.LayerNorm(embed_size)
self.dropout = nn.Dropout(dropout)
def forward(self, value, key, query, t):
        # x1: (25, 12, 64)
        # post-LN Transformer: LayerNorm is applied after the residual addition
x1 = self.norm1(self.STransformer(value, key, query) + query) # (25, 12, 64)
x2 = self.dropout( self.norm2(self.TTransformer(x1, x1, x1, t) + x1) )
return x2
class Encoder(nn.Module):
    # stacks multiple ST-Transformer blocks
def __init__(
self,
embed_size,
num_layers,
heads,
adj,
time_num,
device,
forward_expansion,
dropout,
):
super(Encoder, self).__init__()
self.embed_size = embed_size
self.device = device
self.layers = nn.ModuleList(
[
STTransformerBlock(
embed_size,
heads,
adj,
time_num,
dropout=dropout,
forward_expansion=forward_expansion
)
for _ in range(num_layers)
]
)
self.dropout = nn.Dropout(dropout)
def forward(self, x, t):
#x: input_transformer= [25, 12, 64]
out = self.dropout(x) # out = [25, 12, 64]
# In the Encoder the query, key, value are all the same.
        for layer in self.layers:  # each layer is one STTransformerBlock
# query = value = key = out
out = layer(out, out, out, t)
return out
class Transformer(nn.Module):
def __init__(
self,
adj,
embed_size=64,
num_layers=3,
heads=2,
time_num=288,
forward_expansion=4,
dropout=0,
device="cpu",
):
super(Transformer, self).__init__()
self.encoder = Encoder(
embed_size,
num_layers,
heads,
adj,
time_num,
device,
forward_expansion,
dropout,
)
self.device = device
def forward(self, src, t):
#src: input_transformer
enc_src = self.encoder(src, t)
return enc_src
class STTransformer(nn.Module):
def __init__(
self,
adj,
in_channels = 1,
embed_size = 64,
time_num = 288,
num_layers = 3,
T_dim = 12,
output_T_dim = 3,
heads = 2,
):
super(STTransformer, self).__init__()
        # the first convolution expands the channel dimension from in_channels to embed_size
self.conv1 = nn.Conv2d(in_channels, embed_size, 1) # kernel_size = 1
self.Transformer = Transformer(
adj,
embed_size,
num_layers,
heads,
time_num
)
        # shrink the time dimension, e.g. from T_dim=12 input steps to output_T_dim=3 output steps
        self.conv2 = nn.Conv2d(T_dim, output_T_dim, 1)
        # shrink the channel dimension back to 1
self.conv3 = nn.Conv2d(embed_size, 1, 1)
self.relu = nn.ReLU()
def forward(self, x, t):
        # input x shape [C, N, T] = [1, 25, 12]
        # C: number of channels, N: number of sensors, T: number of time steps
x = x.unsqueeze(0) # (1, 1, 25, 12)
input_Transformer = self.conv1(x) # (1, 64, 25, 12)
input_Transformer = input_Transformer.squeeze(0) # (64, 25, 12) = (C, N, T)
input_Transformer = input_Transformer.permute(1, 2, 0) # (25, 12, 64) = [N, T, C]
# src = (25, 12, 64) = [N, T, C]
output_Transformer = self.Transformer(input_Transformer, t) # (25, 12, 64)
output_Transformer = output_Transformer.permute(1, 0, 2) # (12, 25, 64)
#output_Transformer shape[T, N, C]
output_Transformer = output_Transformer.unsqueeze(0) # (1, 12, 25, 64)
        out = self.relu(self.conv2(output_Transformer))  # out shape: [1, output_T_dim, N, C] = [1, 3, 25, 64]
        out = out.permute(0, 3, 2, 1)                     # out shape: [1, C, N, output_T_dim] = [1, 64, 25, 3]
        out = self.conv3(out)                             # out shape: [1, 1, N, output_T_dim] = [1, 1, 25, 3]
out = out.squeeze(0).squeeze(0) # (25, 3)
return out
        # return out shape: [N, output_T_dim] = [25, 3]
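A minimal usage sketch for the model above (not part of the original repository): it assumes the STTransformer class defined above is in scope together with the project's GCN_models.py and One_hot_encoder.py, and it uses a random adjacency matrix and random input purely to illustrate the expected tensor shapes.
import torch

# toy problem sizes matching the shape comments above
N, T_in, T_out = 25, 12, 3            # sensors, input time steps, predicted time steps
adj = torch.rand(N, N)                # placeholder adjacency matrix (random, for illustration only)
x = torch.rand(1, N, T_in)            # input of shape [C, N, T] with C = 1
t = 0                                 # index of the current time slot of the day (0 .. time_num-1)

model = STTransformer(adj, in_channels=1, embed_size=64, time_num=288,
                      num_layers=3, T_dim=T_in, output_T_dim=T_out, heads=2)
out = model(x, t)
print(out.shape)                      # expected: torch.Size([25, 3])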
2. My code
# -*- coding: utf-8 -*-
"""
Created on Mon Sep 28 10:28:06 2020
@author: wb
"""
import torch
import torch.nn as nn
from GCN_models import GCN
from One_hot_encoder import One_hot_encoder
class SSelfAttention(nn.Module):
def __init__(self, embed_size, heads):
super(SSelfAttention, self).__init__()
self.embed_size = embed_size # 64
self.heads = heads # 8
self.head_dim = embed_size // heads # 8
assert (
self.head_dim * heads == embed_size
), "Embedding size needs to be divisible by heads"
self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
self.fc_out = nn.Linear(heads * self.head_dim, embed_size)
def forward(self, values, keys, query):
N, T, C = query.shape
# Split the embedding into self.heads different pieces
        values = values.reshape(N, T, self.heads, self.head_dim)  # split embed_size into heads x head_dim
keys = keys.reshape(N, T, self.heads, self.head_dim)
query = query.reshape(N, T, self.heads, self.head_dim)
values = self.values(values) # (N, T, heads, head_dim)
keys = self.keys(keys) # (N, T, heads, head_dim)
        queries = self.queries(query)  # (N, T, heads, head_dim)
        energy = torch.einsum("qthd,kthd->qkth", [queries, keys])  # spatial self-attention over the N nodes; energy: (N, N, T, heads)
        attention = torch.softmax(energy / (self.embed_size ** (1 / 2)), dim=1)  # softmax over the key (node) dimension, so the weights sum to 1
out = torch.einsum("qkth,kthd->qthd", [attention, values]).reshape(
N, T, self.heads * self.head_dim
)
out = self.fc_out(out)
# Linear layer doesn't modify the shape, final shape will be
# (N, T, embed_size)
return out
class TSelfAttention(nn.Module):
def __init__(self, embed_size, heads):
super(TSelfAttention, self).__init__()
self.embed_size = embed_size
self.heads = heads
self.head_dim = embed_size // heads
assert (
self.head_dim * heads == embed_size
), "Embedding size needs to be divisible by heads"
self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
self.fc_out = nn.Linear(heads * self.head_dim, embed_size)
def forward(self, values, keys, query):
N, T, C = query.shape
# Split the embedding into self.heads different pieces
        values = values.reshape(N, T, self.heads, self.head_dim)  # split embed_size into heads x head_dim
keys = keys.reshape(N, T, self.heads, self.head_dim)
query = query.reshape(N, T, self.heads, self.head_dim)
values = self.values(values) # (N, T, heads, head_dim)
keys = self.keys(keys) # (N, T, heads, head_dim)
        queries = self.queries(query)  # (N, T, heads, head_dim)
        # queries shape: (N, T, heads, head_dim)
        # keys shape: (N, T, heads, head_dim)
        # energy: (N, T, T, heads)
        energy = torch.einsum("nqhd,nkhd->nqkh", [queries, keys])  # temporal self-attention over the T time steps
        attention = torch.softmax(energy / (self.embed_size ** (1 / 2)), dim=2)  # softmax over the key (time) dimension, so the weights sum to 1
out = torch.einsum("nqkh,nkhd->nqhd", [attention, values]).reshape(
N, T, self.heads * self.head_dim
)
out = self.fc_out(out)
return out
class STransformer(nn.Module):
def __init__(self, embed_size, heads, adj, dropout, forward_expansion, time_num):
super(STransformer, self).__init__()
# Spatial Embedding
self.adj = adj
self.D_S = nn.Parameter(adj)
self.embed_liner = nn.Linear(adj.shape[0], embed_size)
        self.temporal_embedding = nn.Embedding(time_num, embed_size)  # temporal embedding via nn.Embedding
self.attention = SSelfAttention(embed_size, heads)
self.norm1 = nn.LayerNorm(embed_size)
self.norm2 = nn.LayerNorm(embed_size)
self.feed_forward = nn.Sequential(
nn.Linear(embed_size, forward_expansion * embed_size),
nn.ReLU(),
nn.Linear(forward_expansion * embed_size, embed_size),
)
        # GCN branch
        # input: embed_size; hidden: embed_size*2; output: embed_size
        self.gcn = GCN(embed_size, embed_size * 2, embed_size, dropout)
        self.norm_adj = nn.InstanceNorm2d(1)  # normalize the adjacency matrix
self.dropout = nn.Dropout(dropout)
self.fs = nn.Linear(embed_size, embed_size)
self.fg = nn.Linear(embed_size, embed_size)
def forward(self, value, key, query):
X_S = query
        # spatial embedding: project each row of the adjacency matrix to embed_size
        N, T, C = query.shape
        D_S = self.embed_liner(self.D_S)  # (N, embed_size)
        D_S = D_S.expand(T, N, C)         # (T, N, C)
        D_S = D_S.permute(1, 0, 2)        # (N, T, C)
        # temporal embedding
        D_T = self.temporal_embedding(torch.arange(0, T))
        D_T = D_T.expand(N, T, C)         # (25, 12, 64)
        # GCN branch
        X_G = torch.Tensor(query.shape[0], 0, query.shape[2])  # empty (N, 0, C) tensor to concatenate onto
        self.adj = self.adj.unsqueeze(0).unsqueeze(0)
        self.adj = self.norm_adj(self.adj)
        self.adj = self.adj.squeeze(0).squeeze(0)
        # apply the GCN at each time step to extract that step's spatial features
for t in range(query.shape[1]):
o = self.gcn(query[:, t, :], self.adj)
o = o.unsqueeze(1) # shape [N, 1, C]
X_G = torch.cat((X_G, o), dim=1)
        # Spatial Transformer: the spatial and temporal embeddings are added to the input (the original paper concatenates instead)
        X_tildeS = X_S + D_S + D_T
        # dynamical graph conv layer (the spatial self-attention branch)
query = key = value = X_tildeS
M_S = self.attention(value, key, query)
M_S = self.dropout(self.norm1(M_S + query))
M_tilderS = X_tildeS + M_S
# Add skip connection, run through normalization and finally dropout
forward = self.feed_forward(M_tilderS)
        # residual connection, then dropout to reduce overfitting
U_S = self.dropout(self.norm2(forward + M_tilderS))
        # gated fusion of the STransformer output and the GCN output
        g = torch.sigmoid(self.fs(U_S) + self.fg(X_G))  # (11)
        # element-wise product
        out = g * U_S + (1 - g) * X_G  # (12)
return out
class TTransformer(nn.Module):
def __init__(self, embed_size, heads, time_num, dropout, forward_expansion):
super(TTransformer, self).__init__()
        # temporal embedding
        self.time_num = time_num
        self.temporal_embedding = nn.Embedding(time_num, embed_size)  # temporal embedding via nn.Embedding
self.attention = TSelfAttention(embed_size, heads)
self.norm1 = nn.LayerNorm(embed_size)
self.norm2 = nn.LayerNorm(embed_size)
self.feed_forward = nn.Sequential(
nn.Linear(embed_size, forward_expansion * embed_size),
nn.ReLU(),
nn.Linear(forward_expansion * embed_size, embed_size),
)
self.dropout = nn.Dropout(dropout)
def forward(self, value, key, query, t):
X_T = query
        N, T, C = query.shape  # e.g. 25, 12, 64
        D_T = self.temporal_embedding(torch.arange(0, T))  # (12, 64)
        D_T = D_T.expand(N, T, C)                          # (25, 12, 64)
        # add the temporal embedding
X_tildeT = X_T + D_T
M_T = self.attention(X_tildeT, X_tildeT, X_tildeT)
# Add skip connection, run through normalization and finally dropout
M_tildeT = self.dropout(self.norm1(M_T + X_tildeT))
forward = self.feed_forward(M_tildeT)
U_T = self.dropout(self.norm2(forward + M_tildeT))
return U_T
class STTransformerBlock(nn.Module):
def __init__(self, embed_size, heads, adj, time_num, dropout, forward_expansion):
super(STTransformerBlock, self).__init__()
self.STransformer = STransformer(embed_size, heads, adj, dropout, forward_expansion, time_num)
        # time_num is passed through so TTransformer can size its temporal embedding
self.TTransformer = TTransformer(embed_size, heads, time_num, dropout, forward_expansion)
self.norm1 = nn.LayerNorm(embed_size)
self.norm2 = nn.LayerNorm(embed_size)
self.dropout = nn.Dropout(dropout)
def forward(self, value, key, query, t):
X_S = query
        # post-LN Transformer: LayerNorm is applied after the residual addition
Y_S = self.norm1(self.STransformer(X_S, X_S, X_S) + X_S)
X_T = Y_S + X_S
Y_T = self.dropout(self.norm2(self.TTransformer(X_T, X_T, X_T, t) + X_T))
return Y_T
class Encoder(nn.Module):
    # stacks multiple ST-Transformer blocks
def __init__(
self,
embed_size,
num_layers,
heads,
adj,
time_num,
device,
forward_expansion,
dropout,
):
super(Encoder, self).__init__()
self.embed_size = embed_size
self.device = device
self.layers = nn.ModuleList(
[
STTransformerBlock(
embed_size,
heads,
adj,
time_num,
dropout=dropout,
forward_expansion=forward_expansion
)
for _ in range(num_layers)
]
)
self.dropout = nn.Dropout(dropout)
def forward(self, x, t):
# x: input_transformer= [25, 12, 64]
out = self.dropout(x) # out = [25, 12, 64]
# In the Encoder the query, key, value are all the same.
        for layer in self.layers:  # each layer is one STTransformerBlock
# query = value = key = out
out = layer(out, out, out, t)
return out
class Transformer(nn.Module):
def __init__(
self,
adj,
embed_size=64,
num_layers=3,
heads=2,
time_num=288,
forward_expansion=4,
dropout=0,
device="cpu",
):
super(Transformer, self).__init__()
self.encoder = Encoder(
embed_size,
num_layers,
heads,
adj,
time_num,
device,
forward_expansion,
dropout,
)
self.device = device
def forward(self, src, t):
# src: input_transformer
enc_src = self.encoder(src, t)
return enc_src
class STTransformer(nn.Module):
def __init__(
self,
adj,
in_channels=1,
embed_size=64,
time_num=288,
num_layers=3,
T_dim=12,
output_T_dim=3,
heads=2,
):
super(STTransformer, self).__init__()
        # the first convolution expands the channel dimension from in_channels to embed_size
self.conv1 = nn.Conv2d(in_channels, embed_size, 1) # kernel_size = 1
self.Transformer = Transformer(
adj,
embed_size,
num_layers,
heads,
time_num
)
        # shrink the time dimension, e.g. from T_dim=12 input steps to output_T_dim=3 output steps
        self.conv2 = nn.Conv2d(T_dim, output_T_dim, 1)
        # shrink the channel dimension back to 1
self.conv3 = nn.Conv2d(embed_size, 1, 1)
self.relu = nn.ReLU()
def forward(self, x, t):
        # input x shape [C, N, T] = [1, 25, 12]
        # C: number of channels, N: number of sensors, T: number of time steps
x = x.unsqueeze(0) # (1, 1, 25, 12)
input_Transformer = self.conv1(x) # (1, 64, 25, 12)
input_Transformer = input_Transformer.squeeze(0) # (64, 25, 12) = (C, N, T)
input_Transformer = input_Transformer.permute(1, 2, 0) # (25, 12, 64) = [N, T, C]
# src = (25, 12, 64) = [N, T, C]
output_Transformer = self.Transformer(input_Transformer, t) # (25, 12, 64)
output_Transformer = output_Transformer.permute(1, 0, 2) # (12, 25, 64)
# output_Transformer shape[T, N, C]
output_Transformer = output_Transformer.unsqueeze(0) # (1, 12, 25, 64)
        out = self.relu(self.conv2(output_Transformer))  # out shape: [1, output_T_dim, N, C] = [1, 3, 25, 64]
        out = out.permute(0, 3, 2, 1)                     # out shape: [1, C, N, output_T_dim] = [1, 64, 25, 3]
        out = self.conv3(out)                             # out shape: [1, 1, N, output_T_dim] = [1, 1, 25, 3]
out = out.squeeze(0).squeeze(0) # (25, 3)
return out
        # return out shape: [N, output_T_dim] = [25, 3]
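As a sanity check on the attention einsums shared by both versions, here is a self-contained snippet (PyTorch only, with toy dimensions chosen purely for illustration) verifying that the spatial score "qthd,kthd->qkth" equals an explicit Q·Kᵀ over the node axis, computed separately for every time step and head. The temporal einsum "nqhd,nkhd->nqkh" can be checked the same way with the roles of the node and time axes swapped.
import torch

N, T, H, D = 5, 4, 2, 3                      # nodes, time steps, heads, head_dim (toy sizes)
queries = torch.randn(N, T, H, D)
keys = torch.randn(N, T, H, D)

# einsum form used in SSelfAttention: attention scores between every pair of nodes (q, k)
energy = torch.einsum("qthd,kthd->qkth", queries, keys)  # (N, N, T, H)

# explicit form: for each time step t and head h, Q @ K^T over the node axis
explicit = torch.empty(N, N, T, H)
for t in range(T):
    for h in range(H):
        explicit[:, :, t, h] = queries[:, t, h, :] @ keys[:, t, h, :].T

print(torch.allclose(energy, explicit))      # True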
Copyright notice: this is an original article by qq_43118572, released under the CC 4.0 BY-SA license; please include a link to the original source and this notice when reposting.