약초의 숲으로 놀러오세요

Pytorch로 구현한 R-CNN 모델 본문

Computer Vision/Code Review

Pytorch로 구현한 R-CNN 모델

herbwood 2020. 11. 28. 17:54

이번 포스팅에서는 R-CNN 모델을 pytorch를 통해 구현한 코드를 살펴보도록 하겠습니다. 아직 코드 구현에 익숙치 않아 object-detection-algorithm님의 github 저장소에 올라온 R-CNN 모델 구현 코드를 분석했습니다. R-CNN 모델에 대한 설명은 R-CNN 논문 리뷰 포스팅을 참고하시기 바랍니다. 

 

               R-CNN code structure

먼저 object-detection-algorithm(이하 oda)님은 PASCAL VOC 2007 데이터셋의 여러 class 중 "car"에 해당하는 데이터만을 추출하여 사용합니다. 전체 데이터셋을 다 사용할 경우 많은 시간이 걸리기 때문에 특정 class만 추출해서 사용하는 것 같습니다. 

 

R-CNN 모델은 fine tuned AlexNet, linear SVM, Bounding box regressor, 총 3가지 모델을 사용합니다. 각 모델은 데이터에 대한 positive/negative 정의를 다르게 하고 있기 때문에, 학습에 사용하는 데이터셋이 서로 상이합니다. oda님은 위의 그림에서 볼 수 있다시피 서로 다른 모델에 대한 데이터셋을 독립적으로 구축하고, 각각의 방식에 맞는 Custom Dataset을 정의합니다. 사실 모델을 학습하는 부분보다도 서로 다른 데이터셋을 구축하고 load하는 과정이 복잡하게 느껴졌습니다. 코드를 뜯어보고 나서야 R-CNN 모델이 복잡하다는 말을 체감하게 되었습니다😭.

 

그 다음 3가지 모델을 각각 학습 시킨 후 load시켜 추론 시 사용한 후 Non maximum suppression 알고리즘을 적용하여 최종 detection 결과를 반환합니다. 프로젝트 코드 전체 구성은 아래와 같습니다. 

 

|-docs  
|-imgs  
|-py  
  |-data
  |-utils
     |-data
        |-create_bbox_regression_data.py          # Bounding box regressor 학습 데이터 생성
        |-create_classifier_data.py               # linear SVM 학습 데이터 생성
        |-create_finetune_data.py                 # AlexNet fine tune용 데이터 생성
        |-custom_batch_sampler.py                 # mini batch 구성 정의
        |-custom_bbox_regression_dataset.py       # Bounding box regressor custom data loader
        |-custom_classifier_dataset.py            # linear SVM custom data loader
        |-custom_finetune_dataset.py              # AlexNet fine tune custom data loader
        |-custom_hard_negative_mining_dataset.py  # hard negative mining 정의
        |-pascal_voc.py                           # PASCAL VOC 2007 데이터셋 다운로드
        |-pascal_voc_car.py                       # PASCAL VOC 2007 데이터셋에서 car에 해당하는 데이터만 추출
     |-util.py                                    # 기타 메서드 정의
  |-bbox_regression.py                            # Bounding box regressor 학습
  |-car_detector.py                               # 학습시킨 모델을 활용하여 detection
  |-finetune.py                                   # fine tune AlexNet
  |-linear_svm.py                                 # linear SVM 학습
  |-selectivesearch.py                            # Selective search 알고리즘 수행

 

구조가 상당히 복잡하기 때문에 아래와 같은 순서로 살펴보도록 하겠습니다. 코드 전체를 뜯어보기에는 양이 너무 많기 때문에, 제가 분석하면서 중요하다고 생각했던 부분 위주로 살펴보겠습니다. 

 

R-CNN 모델 설계 순서

 

  1. PASCAL VOC 데이터셋 다운로드 및 "car" class 해당하는 데이터만 추출
  2. 각 모델별 annotation 생성 및 Custom Dataset 정의
  3. pre-trained된 AlexNet fine tuning
  4. linear SVM 및 Bounding box regressor 모델 학습
  5. 3가지 모델을 모두 활용하여 detection 수행

1) PASCAL VOC 데이터셋 다운로드 및 Car class에 해당하는 데이터만 추출

  • pascal_voc.py : PASCAL VOC 2007 데이터셋 다운로드
  • pascal_voc_car.py : PASCAL VOC 2007 데이터셋에서 "car"에 해당하는 데이터(이미지, annotation)만 추출

PASCAL VOC 2007 데이터셋은 아래와 같은 구조를 가집니다. 

 

VOC2007
├── Annotations
├── ImageSets
├── JPEGImages
├── SegmentationClass
└── SegmentationObject

 

여기서 JPEGImages에는 모든 이미지가 jpg 형식으로 저장되어 있으며, Annotations에는 각 이미지 파일에 존재하는 class명, 이미지 크기, bounding box 크기 등이 xml 파일로 저장되어 있습니다. ImageSets 디렉터리에는 각 class 이름에 해당하는 텍스트 파일이 존재합니다. 텍스트 파일에는 모든 이미지 파일명과 텍스트 파일명과 동일한 class에 해당하는지 여부(해당할 경우 1, 아닐 경우 -1)가 저장되어 있습니다(ex) 1111 -1).

 

pascal_voc_car.py에서는 ImageSets에 있는 car_trainval.txt 파일을 읽어들여 car에 해당하는 이미지와 xml 파일을 복사하여 별도의 데이터셋을 구축합니다. 

2) 각 모델별 annotation 생성 및 Custom Dataset 정의

  • selectivesearch.py : Selective search 알고리즘 수행
  • create_finetune_data.py : AlexNet fine tune을 수행하기 위한 annotation 생성
  • create_classifier_data.py : linear SVM 학습을 위한 annotation 생성
  • create_bbox_regression_data.py : Bounding box regressor 학습을 위한 annotation 생성

Selective search 알고리즘은 opencv에서 제공하는 메서드를 통해 수행합니다. AlexNet 모델을 fine tuning하기 위해 데이터셋의 annotation을 생성해주는 코드인 create_finetune_data.py는 아래와 같은 순서로 동작합니다.

 

def parse_annotation_jpeg(annotation_path, jpeg_path, gs):

    img = cv2.imread(jpeg_path)

    selectivesearch.config(gs, img, strategy='q')
    rects = selectivesearch.get_rects(gs) # region proposals
    bndboxs = parse_xml(annotation_path) # ground truth boxes

    # get size of the biggest bounding box(region proposals)
    maximum_bndbox_size = 0
    for bndbox in bndboxs:
        xmin, ymin, xmax, ymax = bndbox
        bndbox_size = (ymax - ymin) * (xmax - xmin)
        if bndbox_size > maximum_bndbox_size:
            maximum_bndbox_size = bndbox_size

    # Comparing all region proposals and ground truth
    # return a list of iou results for each region proposals
    iou_list = compute_ious(rects, bndboxs)

    positive_list = list()
    negative_list = list()

    for i in range(len(iou_list)):
        xmin, ymin, xmax, ymax = rects[i]
        rect_size = (ymax - ymin) * (xmax - xmin)

        iou_score = iou_list[i]

        # When fine-tuning the pre-trained CNN model
        # positive : iou >= 0.5
        # negative : iou < 0.5
        # Only the bounding box with iou greater than 0.5 is saved
        if iou_score >= 0.5:
            positive_list.append(rects[i])

        # negative : iou < 0.5 And if it is more than 20% of the largest bounding box
        if 0 < iou_list[i] < 0.5 and rect_size > maximum_bndbox_size / 5.0:
            negative_list.append(rects[i])
        else:
            pass

    return positive_list, negative_list

 

1. 이미지에 Selective search 알고리즘을 적용하여 region proposals를 추출합니다. 그리고 해당 이미지에 대한 xml 파일을 읽어들여 ground truth box를 파악합니다.

2. region proposals와 ground truth box를 비교하여 IoU 값을 도출하고 0.5 이상인 sample은 positive_list, 0.5 미만인 sample은 negative_list에 저장합니다.

3. 그리고 이미지별 region proposal에 대한 위치를positive/negative 여부에 따라 서로다른 csv 파일에 저장합니다. 예를 들어 1111.jpg 파일에서 positive sample에 해당하는 bounding box의 좌표는 1111_1.csv 파일에, negative sample에 해당하는 bounding box는 1111_0.csv 파일에 저장합니다. 

 

위의 과정은 create_classifier_data.pycreate_bbox_regressor_data.py에서도 비슷하게 동작하지만 positive/negative sample에 대한 정의만 다릅니다. 모델에 따른 서로 다른 양성/음성 정의는 이전 포스팅을 살펴보시면 좋을 것 같습니다. 다음으로 모델별 Custom Dataset을 정의하도록 하겠습니다. 

 

  • custom_finetune_dataset.py : AlexNet fine tune하기 위한 Custom Dataset 정의
  • custom_classifier_dataset.py : linear SVM 모델을 학습시키기 위한 Custom Dataset 정의
  • custom_bbox_regression_dataset.py : Bounding box regressor 모델을 학습시키기 위한 Custom Dataset 정의
  • custom_batch_sampler.py : 양성/음성 sample을 mini batch로 구성
  • custom_hard_negative_mining_dataset.py : Hard negative mining을 수행하기 위한 Custom Dataset 정의

custom_finetune_dataset.py는 모델을 학습시키기 위한 대상을 정의합니다. AlexNet 모델은 양성/음성 sample에 해당하는 이미지를 mini batch로 입력받아 학습합니다. CustomFinetuneDataset에서 생성자(__init__)를 통해 앞서 생성한 csv 파일을 읽어들어 postive sample과 negative sample을 서로 다른 리스트에 저장합니다. 그리고 __getitem__ 메서드는 index를 파라미터로 받아, index에 맞는 이미지와 target(양성/음성 여부)를 반환합니다. 

 

class CustomFinetuneDataset(Dataset):

    def __init__(self, root_dir, transform=None):
        samples = parse_car_csv(root_dir)

        # load all car images
        jpeg_images = [cv2.imread(os.path.join(root_dir, 'JPEGImages', sample_name + ".jpg"))
                       for sample_name in samples]

        # positive : iou >= 0.5 negative : iou < 0.5
        # Save positive and negative separately
        positive_annotations = [os.path.join(root_dir, 'Annotations', sample_name + '_1.csv')
                                for sample_name in samples]
        negative_annotations = [os.path.join(root_dir, 'Annotations', sample_name + '_0.csv')
                                for sample_name in samples]

        # bounding box sizes
        positive_sizes = list()
        negative_sizes = list()

        # bounding box coordinates
        positive_rects = list()
        negative_rects = list()

        # positive_rects = [(x, y, w, h), ....]
        # positive_sizes = [1, .....]
        for annotation_path in positive_annotations:
            rects = np.loadtxt(annotation_path, dtype=np.int, delimiter=' ')
            # The existing file is empty or there is only a single line of data in the file
            if len(rects.shape) == 1:
                # Single line
                if rects.shape[0] == 4:
                    positive_rects.append(rects)
                    positive_sizes.append(1)
                else:
                    positive_sizes.append(0)
            else:
                positive_rects.extend(rects)
                positive_sizes.append(len(rects))
                
  (...)

 

custom_classifier_dataset.py 도 이와 같은 방식으로 동작합니다. 하지만 Bounding box regressor의 경우, 이미지 자체를 학습 데이터로 사용하는 것이 아니라 이미지 내의 bounding box의 좌표를 변환시켜주는 값(t_x, t_y)을 학습합니다. 따라서 custom_bbox_regression_dataset.py에서는 아래와 같이 bounding box의 좌표를 입력으로 받아 적절히 변환시켜주는 과정이 필요합니다. 

 

class BBoxRegressionDataset(Dataset):

    def __init__(self, root_dir, transform=None):
(...)

    def __getitem__(self, index: int):
        assert index < self.__len__(), 'The data set size is %d, 
        	and the current input subscript is%d' % (self.__len__(), index)

        box_dict = self.box_list[index]
        image_id = box_dict['image_id']
        positive = box_dict['positive']
        bndbox = box_dict['bndbox']

        # Get predicted image
        jpeg_img = self.jpeg_list[image_id]
        xmin, ymin, xmax, ymax = positive
        image = jpeg_img[ymin:ymax, xmin:xmax]

        if self.transform:
            image = self.transform(image)

        # Calculate x/y/w/h of P/G
        # predicted box width, heigth, centerX coord, centerY coord
        target = dict()
        p_w = xmax - xmin
        p_h = ymax - ymin
        p_x = xmin + p_w / 2
        p_y = ymin + p_h / 2

        # ground truth box width, height, centerX coord, centerY coord
        xmin, ymin, xmax, ymax = bndbox
        g_w = xmax - xmin
        g_h = ymax - ymin
        g_x = xmin + g_w / 2
        g_y = ymin + g_h / 2

        t_x = (g_x - p_x) / p_w
        t_y = (g_y - p_y) / p_h
        t_w = np.log(g_w / p_w)
        t_h = np.log(g_h / p_h)

        return image, np.array((t_x, t_y, t_w, t_h))

 

다음으로 custom_batch_sampler.py 에서는 양성/음성 sample로 mini batch를 구성합니다. CustomBatchSampler 클래스 생성자에서 양성/음성 sample과 각각의 sample 수를 인자로 받습니다. 반복자(__iter__)에서 양성/음성 sample을 지정된 sample 수에 맞게 저장한 후 shuffle해줍니다. 

 

class CustomBatchSampler(Sampler):

    def __init__(self, num_positive, num_negative, batch_positive, batch_negative) -> None:

(...)

    def __iter__(self):
        sampler_list = list()
        for i in range(self.num_iter):
            tmp = np.concatenate(
                (random.sample(self.idx_list[:self.num_positive], self.batch_positive),
                 random.sample(self.idx_list[self.num_positive:], self.batch_negative))
            )
            random.shuffle(tmp)
            sampler_list.extend(tmp)

        return iter(sampler_list)
        
  (...)

 

다음으로 모델별로 정의한 Custom Dataset을 load하여 학습하는 과정을 살펴보겠습니다. 

3) pre-trained된 AlexNet fine tuning

3가지 모델 중 AlexNet을 가장 먼저 학습시켜줘야 합니다. linear SVM 및 Bounding box regressor 모델은 fine tune된 AlexNet에 학습 데이터를 입력시켜 얻은 feature vector를 통해 학습하기 때문입니다. 

 

  • finetune.py : 데이터를 load한 후 pre-trained된 AlexNet에 대한 fine tuning 수행

먼저 pytorch에서 제공하는 transforms 메서드를 통해 AlexNet의 입력 이미지에 맞게 크기를 227x227로 resize시켜줍니다. 그 다음 CustomFinetuneDataset를 통해 데이터를 load시켜줍니다. 이 때 논문에서 언급한 바와 같이 positive sample은 32개, negative sample은 96개가 되도록 CustomBatchSampler를 통해 지정합니다. 

 

if __name__ == '__main__':
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    data_loaders, data_sizes = load_data('./data/finetune_car')

    model = models.alexnet(pretrained=True)
    num_features = model.classifier[6].in_features
    model.classifier[6] = nn.Linear(num_features, 2)
    model = model.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=1e-3, momentum=0.9)
    lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)
    best_model = train_model(data_loaders, model, criterion, optimizer, 
    						lr_scheduler, device=device, num_epochs=25)

    check_dir('./models')
    torch.save(best_model.state_dict(), 'models/alexnet_car.pth')

 

그리고 torchvision에서 제공하는 pre-trained된 AlexNet에서 마지막 layer에 fully connected layer를 추가시켜줍니다. 이 때 출력되는 output unit의 수는 예측하려는 class의 수 + 1(=배경)입니다. 저희가 살펴보는 코드는 오직 자동차에 대한 분류만 진행하기 때문에 output unit의 수는 2입니다. 

 

4) linear SVM 및 Bounding box regressor 학습

다음으로 linear SVM 및 Bounding box regressor 모델을 학습시키는 과정을 살펴보겠습니다. 

  • linear_svm.py : linear SVM 모델 학습(학습 시 hard negative mining 적용)
  • bbox_regression.py : Bounding box regressor 모델 학습

linear SVM 모델을 학습하 과정은 AlexNet을 fine tune하는 과정과 유사하지만 몇 가지 차이가 있습니다. 첫 번째로, 학습 초기 데이터를 구성하는 방법에서 차이가 있습니다. Object detection시, 일반적으로 분류하려는 class는 positive sample, 나머지 배경은 negative sample에 해당합니다. 이는 postive sample의 수가 negative sample보다 훨씬 더 적다는 것을 의미합니다. 이러한 클래스 불균형(class imbalance) 상황에서positive/negative sample을 모두 균형있게 학습하기 위해 학습 초기에는 positive sample과 negative sample의 비율을 1:1로 맞춰줍니다. 

 

두 번째로, Hinge loss를 loss function으로 사용합니다. Hinge loss function은 SVM 모델 학습 시 사용하는 loss function으로, $s_j$를 예측한 class의 confidence score, $s_{y_i}$를 실제 class의 confidence score라고 할 때 다음과 같은 수식으로 표현할 수 있습니다. 

 

$$\sum_{j \ne y_i} max(0, s_j - s_{y_i} + 1)$$

 

Hinge loss를 코드로 구현하면 아래와 같습니다. 

 

def hinge_loss(outputs, labels):
    num_labels = len(labels)
    corrects = outputs[range(num_labels), labels].unsqueeze(0).T

    # Maximum interval
    margin = 1.0
    margins = outputs - corrects + margin
    loss = torch.sum(torch.max(margins, 1)[0]) / len(labels)

    # regularization
    # reg = 1e-3
    # loss += reg * torch.sum(weight ** 2)
    return loss

 

마지막으로, 학습 시 Hard negative mining을 적용합니다. 앞서 살펴보았듯이 positive sample과 같은 수로 negative sample을 구성하다보니 많은 negative sample이 남게 됩니다. 이처럼 남은(remain) negative sample은 linear SVM 모델을 한 차례(1 epoch) 학습시킨 후 Hard negative mining에 사용합니다. 

 

남은 negative sample은 custom_hard_negative_mining_dataset.py 를 통해 데이터셋으로 구성해준 후 별도로 load시켜 준 후 학습시켜줍니다. 즉, 1 epoch 내에서 positive sample과 negative sample으로 학습시킨 후, 남은 negative sample만으로 모델을 다시 학습시켜주는 셈입니다. 학습 결과 모델이 negative(=배경)이라고 정확히 예측한 경우 True Positive sample이며, positive라고 예측했으나, 실제로는 negative인 경우에는 False Positive sample입니다. 아래는 학습 결과를 토대로 hard negative sample(=False Positive)과 easy negative sample(=True Negative)로 구분하는 코드입니다.  

 

def get_hard_negatives(preds, cache_dicts):
    fp_mask = preds == 1
    tn_mask = preds == 0

    fp_rects = cache_dicts['rect'][fp_mask].numpy()
    fp_image_ids = cache_dicts['image_id'][fp_mask].numpy()

    tn_rects = cache_dicts['rect'][tn_mask].numpy()
    tn_image_ids = cache_dicts['image_id'][tn_mask].numpy()

    hard_negative_list = [{'rect': fp_rects[idx], 'image_id': fp_image_ids[idx]} for idx in range(len(fp_rects))]
    easy_negative_list = [{'rect': tn_rects[idx], 'image_id': tn_image_ids[idx]} for idx in range(len(tn_rects))]

    return hard_negative_list, easy_negative_list

 

모델을 남은 negative sample을 학습시켜 찾은 hard negative sample은 negative sample에 추가하여 CustomBatchSampler와 Dataloader를 통해 학습 데이터셋을 다시 구성해줍니다. 즉, 다음 epoch에서는 hard negative sample이 추가된 데이터로 모델을 학습시킬 수 있습니다. 

 

Bounding box regressor를 학습시키는 과정은 상대적으로 단순합니다. fine tuned된 AlexNet 마지막 layer에 output unit = 4인 fully connected layer를 추가해줍니다. 이 때 loss function은 MSE(Mean Squared Error)입니다. 

 

if __name__ == '__main__':
    data_loader = load_data('./data/bbox_regression')

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    feature_model = get_model(device)
    
    in_features = 256 * 6 * 6
    out_features = 4
    model = nn.Linear(in_features, out_features)
    model.to(device)

    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-4)
    lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)

    loss_list = train_model(data_loader, feature_model, model, criterion, optimizer, lr_scheduler, device=device,
                            num_epochs=12)
    util.plot_loss(loss_list)

 

마지막으로 지금까지 학습시킨 3가지 모델을 모두 사용하여 Object detection을 수행하는 코드를 살펴보도록 하겠습니다. 

5) 학습시킨 모델을 활용하여 detection 수행

  • car_detector.py : fine tuned AlexNet, linear SVM, Bounding box regressor 모델을 활용하여 detection 수행. Non maximum suppression 알고리즘 적용.

먼저 fine tuned된 AlexNet을 통해 얻은 feature vector를 linear SVM 모델에 입력하여 confidence score를 얻습니다. 이 때 svm_threshold = 0.6 으로 지정하여 임계값보다 큰 confidence score를 가지는 bounding box만을 저장합니다. 이후 Non max suppression 알고리즘을 적용하여 최종 detection 결과를 얻습니다. 

 

def nms(rect_list, score_list):
    """
    Non-maximum suppression
    :param rect_list: list, the size is [N, 4]
    :param score_list: list, size is [N]
    """
    nms_rects = list()
    nms_scores = list()

    rect_array = np.array(rect_list)
    score_array = np.array(score_list)

    # After sorting once
    # Sort by classification probability from largest to smallest
    idxs = np.argsort(score_array)[::-1]
    rect_array = rect_array[idxs]
    score_array = score_array[idxs]

    thresh = 0.3
    while len(score_array) > 0:
        # Add the bounding box with the highest classification probability
        nms_rects.append(rect_array[0])
        nms_scores.append(score_array[0])
        rect_array = rect_array[1:]
        score_array = score_array[1:]

        length = len(score_array)
        if length <= 0:
            break

        iou_scores = util.iou(np.array(nms_rects[len(nms_rects) - 1]), rect_array)
        # Remove bounding boxes with overlap ratio greater than or equal to thresh
        idxs = np.where(iou_scores < thresh)[0]
        rect_array = rect_array[idxs]
        score_array = score_array[idxs]

    return nms_rects, nms_scores

 

여기서 IoU treshold = 0.3 으로 지정했습니다. 위의 코드를 보면 임계값보다 작은 bounding box의 index만을 남기는 방식으로 Non maximum suppression 알고리즘을 구현하고 있습니다. 


지금까지 pytorch로 구현한 R-CNN 모델의 코드를 살펴보았습니다. 코드 양이 많아 제가 중요하다고 생각하거나, 간과하기 쉬운 부분만 살펴보았습니다. 개인적으로 모델별로 별도의 annotation(csv 파일)을 생성하고 Custom Dataset을 정의하는 부분이 상당히 까다로웠던 것 같습니다. 그리고 linear SVM 모델을 구현한 코드를 보면서 논문에서는 구체적으로 언급하지 않은 Hard negative mining의 동작 원리를 명확하게 알게 된 것 같습니다. 개인적으로 numpy를 활용하여 조건에 맞는 index만을 추출하여 Non max suppression 알고리즘을 구현한 체가 흥미로웠습니다. 그 외에도 PASCAL VOC 데이터셋을 다운받고 파싱하는 과정을 살펴볼 수 있어 좋은 경험이었던 것 같습니다. 

 

최신 모델에 대한 구현 코드는 상대적으로 쉽게 찾을 수 있었으나, R-CNN 모델은 상대적으로 오래된 모델인지라 코드를 찾기 어려웠습니다. 운 좋게 object-detection-algorithm님이 구현한 코드를 찾아 공부할 수 있었고 논문에서 자세히 언급하지 않았던 부분을 명확히 이해할 수 있어서 좋았습니다. 다만 모델마다 별도의 모듈을 구성하여 Custom Dataset 을 정의하고 있는데, 개인적으로 positive/negative sample에 대한 비율을 생성자(__init__)로 받게 하여 하나의 Custom Dataset으로 정의하면 어떨까 하는 생각이 들었습니다. 시간 있을 때 코드를 수정해볼 계획입니다. 

 

저번 포스팅부터 시작해서 R-CNN 모델에 대한 이론과 구현 코드를 살펴보았습니다. 블로그에 글을 올리면 많은 분들이 볼 수도 있다는 생각에 최대한 자세히 논문을 읽고, 코드를 분석했던 것 같습니다🤣. 

Reference

object-detection-algorithm님이 pytorch로 구현한 R-CNN

PASCAL VOC 폴더 계층 구조

R-CNN 논문 리뷰 포스팅

Comments