반응형

MNIST 모델을 머신러닝을 통해 학습하여 모델 파일을

저장하게 되는데 이 모델파일에 대한 테스트를 진행 하고자

테스터 함수를 만들게 되었다.

이 함수는 test_set 폴더에 있는 이미지들을 이용하여

테스트를 진행하고  error rate를 계산하여 출력해 보았다.

 

//
//  MNIST 모델 테스터
//
//  Created by netcanis on 2023/07/20.
//

import cv2
import os
import numpy as np
import pickle
import tensorflow as tf
from keras.models import load_model


def test_model(path, model_file):
    print("Testing in progress...")
    
    # 데이터셋 경로
    test_set_path = os.path.join(path, 'test_set')
    
    total_samples = 0
    error_count = 0
    
    # Load the model from file
    with open(model_file, "rb") as file:
        if model_file.endswith('.h5'):
            loaded_model = load_model(model_file)
        elif model_file.endswith('.tflite'):
            interpreter = tf.lite.Interpreter(model_path = model_file)
            interpreter.allocate_tensors()
            # Get the input and output details
            input_details = interpreter.get_input_details()
            output_details = interpreter.get_output_details()
        else:
            loaded_model = pickle.load(file)
        
        
    # Load the images from the test set
    for digit_folder in os.listdir(test_set_path):
        if os.path.isdir(os.path.join(test_set_path, digit_folder)):
            label = int(digit_folder)
            for index, image_file in enumerate(os.listdir(os.path.join(test_set_path, digit_folder))):
                if image_file.endswith('.png') or image_file.endswith('.jpg'):

                    image = cv2.imread(os.path.join(test_set_path, digit_folder, image_file))
                    if path.endswith('/credit_card'):
                        image = cv2.resize(image, (32, 32))
                    else: # '/MNIST'
                        image = cv2.resize(image, (28, 28))

                    # Convert color image to grayscale if necessary
                    if image.shape[2] > 1:
                        image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
                            
                    if model_file.endswith('.h5'):
                        image = image.reshape(1, 28, 28, 1) # 배열을 4차원으로 변경.
                        image = image / 255.0
                        # Predict the label for the image
                        predicted_label = loaded_model.predict(image)
                        # Get the predicted class
                        predicted_class = np.argmax(predicted_label)
                    elif model_file.endswith('.tflite'):
                        image = np.expand_dims(image, axis=0)  # Add batch dimension
                        image = np.expand_dims(image, axis=3)  # Add channel dimension
                        image = image.astype(np.float32) / 255.0
                        # Set the input tensor
                        interpreter.set_tensor(input_details[0]['index'], image)
                        # Run the inference
                        interpreter.invoke()
                        # Get the output tensor
                        output_tensor = interpreter.get_tensor(output_details[0]['index'])
                        # Get the predicted class
                        predicted_class = np.argmax(output_tensor)
                    else: # SVM, RandomForestClassifier ('.pkl')                       
                        # Reshape the data
                        image = image.reshape(1, -1)
                        # Predict the label for the image
                        predicted_label = loaded_model.predict(image)
                        # Get the predicted class
                        predicted_class = predicted_label[0]
                    
                    # error 
                    if predicted_class != label:
                        error_count += 1
                        print(f"Prediction for {index} - {label}: {predicted_class}")
                    
                    total_samples += 1


    # Print error rate
    error_rate = (error_count / total_samples) * 100
    print(f"Error rate: {error_rate:.2f}%")

 

사용 방법은 다음과 같다.

import mnist_model_tester

model_file = "svm_mnist_model.pkl"
mnist_model_tester.test_mnist_model("data/MNIST", model_file)

or

from mnist_model_tester import test_mnist_model

model_file = "svm_mnist_model.pkl"
test_mnist_model("data/MNIST", model_file)

 

2023.07.19 - [AI] - MNIST 데이터셋 다운로드

2023.07.19 - [AI] - MNIST 데이터셋을 이미지 파일로 복원

2023.07.19 - [AI] - MNIST 데이터셋 로더

2023.07.19 - [AI] - MNIST 모델 테스터

2023.07.19 - [AI] - MINST - SVC(Support Vector Classifier)

2023.07.19 - [AI] - MNIST - RandomForestClassifier

2023.07.19 - [AI] - MNIST - Keras

2023.07.19 - [AI] - MNIST - TensorFlowLite

 

 

반응형

'개발 > AI,ML,ALGORITHM' 카테고리의 다른 글

MNIST - RandomForestClassifier  (0) 2023.07.19
MINST - SVC(Support Vector Classifier)  (0) 2023.07.19
MNIST 데이터셋 로더  (0) 2023.07.19
MNIST 데이터셋을 이미지 파일로 복원  (0) 2023.07.19
MNIST 데이터셋 다운로드  (0) 2023.07.19
블로그 이미지

SKY STORY

,