函数原型：

```python
tf.keras.layers.Embedding(input_dim, output_dim, embeddings_initializer='uniform', embeddings_regularizer=None, activity_regularizer=None, embeddings_constraint=None, mask_zero=False, input_length=None, **kwargs)
```
举一个图像数据的例子：
这里 Embedding 的 input_dim 是 10（标签共 10 类，标签取值必须在 [0, 10) 内），output_dim 是 np.prod((28, 28, 1)) = 784，与展平后的图像长度相同。因为 multiply 是逐元素相乘，两者的形状必须一致。
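```python
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import multiply, Flatten, Embedding

img = tf.ones((2, 28, 28, 1))
label = np.array([1, 2])
# Embedding maps an input of shape (2,) to (2, 784); with a 1-D label Flatten
# leaves the shape unchanged. It only matters when label is (batch, 1), in which
# case Embedding yields (batch, 1, 784) and Flatten turns it into (batch, 784).
label_embedding = Flatten()(Embedding(10, np.prod((28, 28, 1)))(label))
flat_img = Flatten()(img)

print(label_embedding.shape)  # (2, 784)

model_input = multiply([flat_img, label_embedding])

print(model_input.shape)  # (2, 784)
```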
再举一个一维噪声的例子（此时 Embedding 的 output_dim 取噪声的维度 100）：
```python
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import multiply, Flatten, Embedding

noise = tf.ones((2, 100))
label = np.array([1, 2])
label_embedding = Flatten()(Embedding(10, 100)(label))

print(label_embedding.shape)  # (2, 100)

model_input = multiply([noise, label_embedding])
print(model_input.shape)  # (2, 100)
```