반응형
    
    
    
  MNIST 모델을 머신러닝을 통해 학습하여 모델 파일을
저장하게 되는데 이 모델파일에 대한 테스트를 진행 하고자
테스터 함수를 만들게 되었다.
이 함수는 test_set 폴더에 있는 이미지들을 이용하여
테스트를 진행하고 error rate를 계산하여 출력해 보았다.
//
//  MNIST 모델 테스터
//
//  Created by netcanis on 2023/07/20.
//
import cv2
import os
import numpy as np
import pickle
import tensorflow as tf
from keras.models import load_model
def test_model(path, model_file):
    print("Testing in progress...")
    
    # 데이터셋 경로
    test_set_path = os.path.join(path, 'test_set')
    
    total_samples = 0
    error_count = 0
    
    # Load the model from file
    with open(model_file, "rb") as file:
        if model_file.endswith('.h5'):
            loaded_model = load_model(model_file)
        elif model_file.endswith('.tflite'):
            interpreter = tf.lite.Interpreter(model_path = model_file)
            interpreter.allocate_tensors()
            # Get the input and output details
            input_details = interpreter.get_input_details()
            output_details = interpreter.get_output_details()
        else:
            loaded_model = pickle.load(file)
        
        
    # Load the images from the test set
    for digit_folder in os.listdir(test_set_path):
        if os.path.isdir(os.path.join(test_set_path, digit_folder)):
            label = int(digit_folder)
            for index, image_file in enumerate(os.listdir(os.path.join(test_set_path, digit_folder))):
                if image_file.endswith('.png') or image_file.endswith('.jpg'):
                    image = cv2.imread(os.path.join(test_set_path, digit_folder, image_file))
                    if path.endswith('/credit_card'):
                        image = cv2.resize(image, (32, 32))
                    else: # '/MNIST'
                        image = cv2.resize(image, (28, 28))
                    # Convert color image to grayscale if necessary
                    if image.shape[2] > 1:
                        image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
                            
                    if model_file.endswith('.h5'):
                        image = image.reshape(1, 28, 28, 1) # 배열을 4차원으로 변경.
                        image = image / 255.0
                        # Predict the label for the image
                        predicted_label = loaded_model.predict(image)
                        # Get the predicted class
                        predicted_class = np.argmax(predicted_label)
                    elif model_file.endswith('.tflite'):
                        image = np.expand_dims(image, axis=0)  # Add batch dimension
                        image = np.expand_dims(image, axis=3)  # Add channel dimension
                        image = image.astype(np.float32) / 255.0
                        # Set the input tensor
                        interpreter.set_tensor(input_details[0]['index'], image)
                        # Run the inference
                        interpreter.invoke()
                        # Get the output tensor
                        output_tensor = interpreter.get_tensor(output_details[0]['index'])
                        # Get the predicted class
                        predicted_class = np.argmax(output_tensor)
                    else: # SVM, RandomForestClassifier ('.pkl')                       
                        # Reshape the data
                        image = image.reshape(1, -1)
                        # Predict the label for the image
                        predicted_label = loaded_model.predict(image)
                        # Get the predicted class
                        predicted_class = predicted_label[0]
                    
                    # error 
                    if predicted_class != label:
                        error_count += 1
                        print(f"Prediction for {index} - {label}: {predicted_class}")
                    
                    total_samples += 1
    # Print error rate
    error_rate = (error_count / total_samples) * 100
    print(f"Error rate: {error_rate:.2f}%")
사용 방법은 다음과 같다.
import mnist_model_tester
model_file = "svm_mnist_model.pkl"
mnist_model_tester.test_mnist_model("data/MNIST", model_file)
or
from mnist_model_tester import test_mnist_model
model_file = "svm_mnist_model.pkl"
test_mnist_model("data/MNIST", model_file)
2023.07.19 - [AI] - MNIST 데이터셋 다운로드
2023.07.19 - [AI] - MNIST 데이터셋을 이미지 파일로 복원
2023.07.19 - [AI] - MNIST 데이터셋 로더
2023.07.19 - [AI] - MNIST 모델 테스터
2023.07.19 - [AI] - MINST - SVC(Support Vector Classifier)
2023.07.19 - [AI] - MNIST - RandomForestClassifier
2023.07.19 - [AI] - MNIST - Keras
2023.07.19 - [AI] - MNIST - TensorFlowLite
반응형
    
    
    
  '개발 > AI,ML,ALGORITHM' 카테고리의 다른 글
| MNIST - RandomForestClassifier (0) | 2023.07.19 | 
|---|---|
| MINST - SVC(Support Vector Classifier) (0) | 2023.07.19 | 
| MNIST 데이터셋 로더 (0) | 2023.07.19 | 
| MNIST 데이터셋을 이미지 파일로 복원 (0) | 2023.07.19 | 
| MNIST 데이터셋 다운로드 (0) | 2023.07.19 |