pytorch, tensorflow

[tensorflow 2.0] model.save, model.save_weights 차이

2021. 4. 1. 17:40
목차
  1. 1) model.save()
  2. 2) model.save_weights()

tensorflow 2.0 버전부터 model, layer subclassing을 지원한다. 이때 custom model을 학습 후 저장할 때 model.save()를 사용하면 아래와 같은 에러 문구가 뜨는 경우가 발생하는 경우가 있다.

 

>>> ValueError: Model <__main__.CustomModel object at 0x7f96797a2c10> cannot be saved because the input shapes have not been set. Usually, input shapes are automatically determined from calling .fit() or .predict(). To manually set the shapes, call model._set_inputs(inputs).

 

이 경우 model.save_weights()를 사용하면 해결이 된다. 두 method의 차이점을 알아보자.

 

 

1) model.save()

class CustomModel(tf.keras.Model):
    def __init__(self, hidden_units):
        super(CustomModel, self).__init__()
        self.dense_layers = [
            tf.keras.layers.Dense(u) for u in hidden_units]
    def call(self, inputs):
        x = inputs
        for layer in self.dense_layers:
            x = layer(x)
        return x

model = CustomModel([16, 16, 10])
# Build the model by calling it
input_arr = tf.random.uniform((1, 5))
outputs=model(input_arr)
model.save('my_custom_model')

 

model.save()를 호출하면 다음과 같은 파일들이 저장된다.

  • model's architecture/config
  • model's weight values
  • model's compilation information (if compile() was called)
  • optimizer and its state

 

model.save('my_custom_model')시 저장되는 디렉토리는 다음과 같다.

my_custom_model
  └ assets (dir)
  └ variables (dir)
  └ saved_model.pb (file)

 

model.save()와 같이 사용하는 save & load API는 다음과 같다.

  • model.save() 또는 tf.keras.models.save_model()
  • tf.keras.models.load_model()

 

 

2) model.save_weights()

model = CustomModel([16, 16, 10])
# Build the model by calling it
input_arr = tf.random.uniform((1, 5))
outputs=model(input_arr)
model.save_weights('my_custom_model')

model.save_weights()를 호출하면 tf.train.Checkpoint을 사용할 때와 동일한 파일들이 저장된다. tensorflow 문서에서는 checkpoint를 통한 저장을 권장하고 있다.

 

좀 짜증나는 것은 문서에서는 file_path를 명시하고 있는데 정작 model.save_weights('my_custom_model')을 호출하면 파일들의 prefix로 붙어서 저장된다. 

>>> model.save_weights('my_custom_model')

checkpoint
my_custom_model.data-00000-of-00001
my_custom_model.index

때문에 위 3개 파일이 저장될 폴더를 지정해주고 os.path.join()으로 file_path를 지정해줘야 한다.

 

model.save_weights()와 같이 사용하는 save & load API는 다음과 같다.

  • model.load_weights()

 

[reference]

keras - serialization and saving(링크)

tf.keras.Model documentation(링크)

728x90
저작자표시 비영리 동일조건 (새창열림)

'pytorch, tensorflow' 카테고리의 다른 글

[PyTorch] Embedding 추가하기  (0) 2022.10.01
[PyTorch] torch.cat(), torch.stack() 비교  (0) 2022.10.01
[pytorch] torch.reshape에 관하여  (0) 2022.07.11
[pytorch] torch에서 parameter 접근하기  (0) 2021.03.19
  1. 1) model.save()
  2. 2) model.save_weights()
'pytorch, tensorflow' 카테고리의 다른 글
  • [PyTorch] Embedding 추가하기
  • [PyTorch] torch.cat(), torch.stack() 비교
  • [pytorch] torch.reshape에 관하여
  • [pytorch] torch에서 parameter 접근하기
Fine애플
Fine애플
이것저것
끄적끄적이것저것
Fine애플
끄적끄적
Fine애플
전체
오늘
어제
  • 분류 전체보기 (167)
    • 논문 및 개념 정리 (27)
    • Pattern Recognition (8)
    • 개발 (57)
    • python 메모 (45)
    • pytorch, tensorflow (5)
    • 알고리즘 (9)
    • Toy Projects (4)
    • 통계이론 (2)
    • Reinforcement Learning (10)

블로그 메뉴

  • 홈

공지사항

인기 글

태그

  • ubuntu
  • transformer
  • 자연어
  • tensorflow
  • 언어모델
  • Bert
  • reinforcement learning
  • Probability
  • BigBird
  • GPU
  • 알고리즘
  • pandas
  • python
  • PyTorch
  • miniconda
  • 개발환경
  • Docker
  • 딥러닝
  • container
  • nlp

최근 댓글

최근 글

hELLO · Designed By 정상우.
Fine애플
[tensorflow 2.0] model.save, model.save_weights 차이
상단으로

티스토리툴바

개인정보

  • 티스토리 홈
  • 포럼
  • 로그인

단축키

내 블로그

내 블로그 - 관리자 홈 전환
Q
Q
새 글 쓰기
W
W

블로그 게시글

글 수정 (권한 있는 경우)
E
E
댓글 영역으로 이동
C
C

모든 영역

이 페이지의 URL 복사
S
S
맨 위로 이동
T
T
티스토리 홈 이동
H
H
단축키 안내
Shift + /
⇧ + /

* 단축키는 한글/영문 대소문자로 이용 가능하며, 티스토리 기본 도메인에서만 동작합니다.