Function signature

```python
tf.keras.layers.Embedding(input_dim, output_dim, embeddings_initializer='uniform',
                          embeddings_regularizer=None, activity_regularizer=None,
                          embeddings_constraint=None, mask_zero=False, input_length=None, **kwargs)
```
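The two required arguments are easiest to understand by looking at the layer's weights. Below is a minimal sketch, assuming TensorFlow 2.x. It shows that the layer holds a single table of shape `(input_dim, output_dim)` and that calling it on integer ids just looks up rows of that table. The sizes 10 and 4 are arbitrary, chosen only for illustration.

```python
import numpy as np
import tensorflow as tf

# 10 possible integer ids, each mapped to a 4-dim vector (sizes chosen for illustration)
layer = tf.keras.layers.Embedding(input_dim=10, output_dim=4)

ids = np.array([1, 2, 1])
out = layer(ids)  # the first call builds the weight table

# The layer's only weight is a (input_dim, output_dim) lookup table
table = layer.get_weights()[0]
print(table.shape)  # (10, 4)
print(out.shape)    # (3, 4)

# Each output row is the table row selected by the corresponding id
print(np.allclose(out.numpy(), table[ids]))  # True
```

Repeated ids (here `1` twice) return identical rows, because they index the same row of the table.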

An example with image data

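Here `input_dim` is 10 (ten label classes) and `output_dim` is `np.prod((28, 28, 1))`, i.e. 784, so each label is mapped to a vector with as many entries as a flattened 28×28×1 image.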

```python
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import multiply, Flatten, Embedding

img = tf.ones((2, 28, 28, 1))  # two 28x28 single-channel images
label = np.array([1, 2])       # two class labels

# Map each label to a 784-dim vector, the size of one flattened image
label_embedding = Flatten()(Embedding(10, int(np.prod((28, 28, 1))))(label))
flat_img = Flatten()(img)

print(label_embedding.shape)
# (2, 784)
print(flat_img.shape)
# (2, 784)

model_input = multiply([flat_img, label_embedding])  # element-wise product
print(model_input.shape)
# (2, 784)
```
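`multiply` requires inputs of the same shape and combines them element by element. A quick sketch (assuming TensorFlow 2.x, and repeating the setup above so it runs on its own) confirms this. Because every pixel of `img` is 1.0 in this example, the product is simply the label embedding itself.

```python
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import multiply, Flatten, Embedding

img = tf.ones((2, 28, 28, 1))
label = np.array([1, 2])

label_embedding = Flatten()(Embedding(10, int(np.prod((28, 28, 1))))(label))
flat_img = Flatten()(img)
model_input = multiply([flat_img, label_embedding])

# multiply() gives the same result as the * operator on equally shaped tensors
print(np.allclose(model_input.numpy(), (flat_img * label_embedding).numpy()))  # True

# All pixels are 1.0, so the product equals the label embedding unchanged
print(np.allclose(model_input.numpy(), label_embedding.numpy()))  # True
```

With real images the pixel values scale each entry of the label vector, so the same image produces different model inputs for different labels.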

Another example with 1-D noise vectors

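```python
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import multiply, Flatten, Embedding

noise = tf.ones((2, 100))  # two 100-dim noise vectors
label = np.array([1, 2])   # two class labels

# Map each label to a 100-dim vector, matching the noise length
label_embedding = Flatten()(Embedding(10, 100)(label))

print(label_embedding.shape)
# (2, 100)
print(noise.shape)
# (2, 100)

model_input = multiply([noise, label_embedding])  # element-wise product
print(model_input.shape)
# (2, 100)
```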
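In both examples above `label` has shape `(2,)`, so the embedding output is already `(2, dim)` and `Flatten` does not change it. `Flatten` becomes necessary when each label is fed as a length-1 sequence, which is the usual layout for a Keras `Input`. The sketch below assumes TensorFlow 2.x and a functional model with a label input of shape `(1,)`; that input layout and the names `num_classes` and `latent_dim` are assumptions for illustration, not taken from the examples above.

```python
import numpy as np
import tensorflow as tf
from tensorflow.keras import Input, Model
from tensorflow.keras.layers import Embedding, Flatten, multiply

num_classes, latent_dim = 10, 100  # illustrative sizes matching the noise example

noise = Input(shape=(latent_dim,))        # (batch, 100)
label = Input(shape=(1,), dtype="int32")  # (batch, 1): one class id per sample

# Embedding appends an axis: (batch, 1) -> (batch, 1, 100); Flatten -> (batch, 100)
label_embedding = Flatten()(Embedding(num_classes, latent_dim)(label))
model_input = multiply([noise, label_embedding])

model = Model([noise, label], model_input)

z = np.ones((2, latent_dim), dtype="float32")
y = np.array([[1], [2]], dtype="int32")
out = model.predict([z, y], verbose=0)
print(out.shape)  # (2, 100)
```

Without `Flatten`, the embedding output would stay `(batch, 1, 100)` and would not match the `(batch, 100)` noise tensor element for element.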