tensorflow 2.0 버전부터 model, layer subclassing을 지원한다. 이때 custom model을 학습 후 저장할 때 model.save()
를 사용하면 아래와 같은 에러 문구가 뜨는 경우가 발생하는 경우가 있다.
>>> ValueError: Model <__main__.CustomModel object at 0x7f96797a2c10> cannot be saved because the input shapes have not been set. Usually, input shapes are automatically determined from calling .fit() or .predict(). To manually set the shapes, call model._set_inputs(inputs).
이 경우 model.save_weights()
를 사용하면 해결이 된다. 두 method의 차이점을 알아보자.
1) model.save()
class CustomModel(tf.keras.Model):
def __init__(self, hidden_units):
super(CustomModel, self).__init__()
self.dense_layers = [
tf.keras.layers.Dense(u) for u in hidden_units]
def call(self, inputs):
x = inputs
for layer in self.dense_layers:
x = layer(x)
return x
model = CustomModel([16, 16, 10])
# Build the model by calling it
input_arr = tf.random.uniform((1, 5))
outputs=model(input_arr)
model.save('my_custom_model')
model.save()
를 호출하면 다음과 같은 파일들이 저장된다.
- model's architecture/config
- model's weight values
- model's compilation information (if compile() was called)
- optimizer and its state
model.save('my_custom_model')
시 저장되는 디렉토리는 다음과 같다.
my_custom_model
└ assets (dir)
└ variables (dir)
└ saved_model.pb (file)
model.save()
와 같이 사용하는 save & load API는 다음과 같다.
model.save()
또는tf.keras.models.save_model()
tf.keras.models.load_model()
2) model.save_weights()
model = CustomModel([16, 16, 10])
# Build the model by calling it
input_arr = tf.random.uniform((1, 5))
outputs=model(input_arr)
model.save_weights('my_custom_model')
model.save_weights()
를 호출하면 tf.train.Checkpoint
을 사용할 때와 동일한 파일들이 저장된다. tensorflow 문서에서는 checkpoint를 통한 저장을 권장하고 있다.
좀 짜증나는 것은 문서에서는 file_path를 명시하고 있는데 정작 model.save_weights('my_custom_model')
을 호출하면 파일들의 prefix로 붙어서 저장된다.
>>> model.save_weights('my_custom_model')
checkpoint
my_custom_model.data-00000-of-00001
my_custom_model.index
때문에 위 3개 파일이 저장될 폴더를 지정해주고 os.path.join()
으로 file_path를 지정해줘야 한다.
model.save_weights()
와 같이 사용하는 save & load API는 다음과 같다.
model.load_weights()
[reference]
keras - serialization and saving(링크)
tf.keras.Model documentation(링크)
'pytorch, tensorflow' 카테고리의 다른 글
[PyTorch] Embedding 추가하기 (0) | 2022.10.01 |
---|---|
[PyTorch] torch.cat(), torch.stack() 비교 (0) | 2022.10.01 |
[pytorch] torch.reshape에 관하여 (0) | 2022.07.11 |
[pytorch] torch에서 parameter 접근하기 (0) | 2021.03.19 |