submit code
This commit is contained in:
43
tools/cat_images.py
Normal file
43
tools/cat_images.py
Normal file
@@ -0,0 +1,43 @@
|
||||
import cv2
|
||||
import numpy as np
|
||||
import os
|
||||
import pathlib
|
||||
import argparse
|
||||
|
||||
def group_images(path_list):
|
||||
sorted(path_list)
|
||||
class_id_dict = {}
|
||||
for path in path_list:
|
||||
class_id = str(path.name).split('_')[0]
|
||||
if class_id not in class_id_dict:
|
||||
class_id_dict[class_id] = []
|
||||
class_id_dict[class_id].append(path)
|
||||
return class_id_dict
|
||||
|
||||
def cat_images(path_list):
|
||||
imgs = []
|
||||
for path in path_list:
|
||||
img = cv2.imread(str(path))
|
||||
os.remove(path)
|
||||
imgs.append(img)
|
||||
row_cat_images = []
|
||||
row_length = int(len(imgs)**0.5)
|
||||
for i in range(len(imgs)//row_length):
|
||||
row_cat_images.append(np.concatenate(imgs[i*row_length:(i+1)*row_length], axis=1))
|
||||
cat_image = np.concatenate(row_cat_images, axis=0)
|
||||
return cat_image
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--src_dir', type=str, default=None)
|
||||
|
||||
args = parser.parse_args()
|
||||
src_dir = args.src_dir
|
||||
path_list = list(pathlib.Path(src_dir).glob('*.png'))
|
||||
class_id_dict = group_images(path_list)
|
||||
for class_id, path_list in class_id_dict.items():
|
||||
cat_image = cat_images(path_list)
|
||||
cat_path = os.path.join(src_dir, f'cat_{class_id}.jpg')
|
||||
# cat_path = "cat_{}.png".format(class_id)
|
||||
cv2.imwrite(cat_path, cat_image)
|
||||
|
||||
Reference in New Issue
Block a user