image/main.py

248 lines
9.9 KiB
Python

import simpleimageio as sio
import figuregen
from figuregen.util.templates import FullSizeWithCrops
from figuregen.util.image import Cropbox
import os
import argparse
import yaml
import numpy as np
class MyCrops(FullSizeWithCrops):
    """Full-size reference + method images, each paired with its own crops.

    For every image (the reference first, then each method image in order) this
    builds:
      * a 1x1 grid with the tonemapped full-size image, carrying one white
        marker per crop, stored in ``self._ref_grid`` (a list, one per image);
      * a grid with the tonemapped crops of that image, stored in
        ``self._crop_grid`` (a list, one per image) -- a single row when
        ``crops_below`` is True, otherwise a single column.

    NOTE(review): ``main`` reads ``.figure`` (not defined here) and indexes it as
    ``[0]`` / ``[1]``; that property is assumed to come from the base class and
    to expose these lists appropriately -- confirm against figuregen.
    """

    def __init__(self, reference_image, method_images, crops,
                 crops_below=True, method_names=None, use_latex=False, add=False):
        """Build the per-image full-size grids and crop grids.

        Args:
            reference_image: the reference image; it is shown first.
            method_images: list of method images, shown after the reference in order.
            crops: list of crop boxes; each one yields one crop per image and one
                marker on every full-size image.
            crops_below: if True, each image's crops form a single row (padding is
                set below the full-size image); if False, a single column.
            method_names: labels for the reference and each method, in the order
                ``[reference_image, *method_images]``. Only used when ``add`` is
                True and ``crops_below`` is True.
            use_latex: stored as ``self.use_latex``; read by ``error_string``.
                Per the original docs, LaTeX captions require the TikZ backend.
            add: if True (and ``crops_below``), put ``method_names[i]`` as a
                bottom title under the crop grid of image ``i``.

        Raises:
            ValueError: if ``add`` is True but ``method_names`` is None or has
                fewer entries than there are images (reference + methods).
        """
        images = [reference_image, *method_images]

        # Validate before doing any (potentially expensive) work; previously this
        # surfaced only at the very end as a TypeError / IndexError.
        if add:
            if method_names is None:
                raise ValueError("add=True requires method_names to be given")
            if len(method_names) < len(images):
                raise ValueError(
                    f"add=True needs {len(images)} method_names (reference + "
                    f"{len(method_images)} methods), got {len(method_names)}"
                )

        self._reference_image = reference_image
        self._method_images = method_images
        self.use_latex = use_latex
        self._crops_below = crops_below

        # Whole-image and per-crop errors against the reference. They are not
        # displayed yet (see TODO at the end) but are kept as instance state.
        self._errors = [self.compute_error(reference_image, m) for m in method_images]
        self._crop_errors = [
            [self.compute_error(crop.crop(reference_image), crop.crop(m)) for m in method_images]
            for crop in crops
        ]

        # One full-size 1x1 grid per image, with a white marker for every crop.
        self._ref_grid = []
        for img in images:
            grid = figuregen.Grid(1, 1)
            grid[0, 0].image = self.tonemap(img)
            for crop in crops:
                grid[0, 0].set_marker(crop.marker_pos, crop.marker_size, color=[255, 255, 255])
            self._ref_grid.append(grid)

        # One crop grid per image: a row of crops (below) or a column (beside).
        self._crop_grid = []
        for img in images:
            if self._crops_below:
                grid = figuregen.Grid(num_cols=len(crops), num_rows=1)
            else:
                grid = figuregen.Grid(num_cols=1, num_rows=len(crops))
            for k, crop in enumerate(crops):
                cell = grid[0, k] if self._crops_below else grid[k, 0]
                cell.image = self.tonemap(crop.crop(img))
            self._crop_grid.append(grid)

        # Gap to the right of every image except the last one.
        for ref_grid, crop_grid in zip(self._ref_grid[:-1], self._crop_grid[:-1]):
            ref_grid.layout.set_padding(right=1)
            crop_grid.layout.set_padding(right=1)
            if self._crops_below:
                ref_grid.layout.set_padding(bottom=1)
        # Gap between the last full-size image and its crops.
        if self._crops_below:
            self._ref_grid[-1].layout.set_padding(bottom=1)
        else:
            self._ref_grid[-1].layout.set_padding(right=1)

        # Optional method names under each crop row.
        if self._crops_below and add:
            for crop_grid, name in zip(self._crop_grid, method_names):
                crop_grid.set_title("bottom", name)
                # NOTE(review): positional args presumably (field size mm,
                # offset mm, font size) -- confirm against figuregen's layout API.
                crop_grid.layout.set_title("bottom", 6, 1, 8)

        # TODO: show error values (self._errors / error_string /
        # error_metric_name) together with the method names; this needs titles
        # spanning an image grid and its crop grid. Also: titles for the
        # crops-beside layout, and alignment paddings.

    def error_string(self, index, errors):
        """Return the human-readable error label for ``errors[index]``.

        Args:
            index: index in the list of errors.
            errors: list of error values, one per method, in order.

        Returns:
            An empty string in every branch, i.e. error labels are currently
            blank. When ``use_latex`` is set, ``np.argmin(errors)`` is still
            evaluated first (so an empty ``errors`` raises ``ValueError``).
        """
        if self.use_latex and index == np.argmin(errors):
            return f""
        elif self.use_latex:
            return f""
        else:
            return f""

    def error_metric_name(self) -> str:
        """Return the label for the error metric; this override returns ``""``."""
        return ""
# Baseline methods, in column order. Results for each are read from
# <data_root>/<method>/<dataset_type>/<image_file>.
_BASELINE_METHODS = ("KinD", "LIME", "NPE", "SRIE", "ZeroDCE")

# Column labels handed to MyCrops: the "original" input, the baselines, then ours.
_METHOD_NAMES = ("Reference", *_BASELINE_METHODS, "Ours")


def _load_row(data_root, dataset_type, image_file, ours_dir):
    """Read the "original" image and every method's result for one scene.

    Returns:
        ``(reference_image, method_images)``, where ``method_images`` follows
        ``_BASELINE_METHODS`` order and ends with the ``ours_dir`` result.
    """
    def read(method_dir):
        return sio.read(os.path.join(data_root, method_dir, dataset_type, image_file))

    reference = read("original")
    methods = [read(method) for method in _BASELINE_METHODS]
    methods.append(read(ours_dir))
    return reference, methods


def _to_cropboxes(crops):
    """Turn ``[top, left, height, width, scale]`` entries into ``Cropbox`` objects."""
    return [
        Cropbox(top=c[0], left=c[1], height=c[2], width=c[3], scale=c[4])
        for c in crops
    ]


def main(data_root, image_file_1, image_file_2, image_file_3, dataset_type, ours_method,
         crops_1, crops_2, crops_3, width_cm, output_image,
         ours_method_1=None, ours_method_2=None, ours_method_3=None):
    """Build and save a three-scene crop comparison figure.

    Args:
        data_root: directory holding one subdirectory per method ("original",
            the baselines in ``_BASELINE_METHODS``, and ours), each containing a
            ``dataset_type`` subdirectory with the images.
        image_file_1, image_file_2, image_file_3: file names of the three scenes,
            in the order their grids are passed to ``figuregen.figure``.
        dataset_type: subdirectory name used under every method directory.
        ours_method: directory name of our method's results, used for every
            scene without a per-scene override.
        crops_1, crops_2, crops_3: per-scene lists of
            ``[top, left, height, width, scale]`` crop specs.
        width_cm: figure width passed to ``figuregen.figure``.
        output_image: output file; made absolute, and its parent directory is
            created if missing.
        ours_method_1, ours_method_2, ours_method_3: optional per-scene override
            of ``ours_method`` (any falsy value falls back to ``ours_method``).
    """
    image_files = (image_file_1, image_file_2, image_file_3)
    crop_specs = (crops_1, crops_2, crops_3)
    ours_dirs = (
        ours_method_1 or ours_method,
        ours_method_2 or ours_method,
        ours_method_3 or ours_method,
    )

    # Read every image up front so a missing file fails before any figure work.
    rows = [
        _load_row(data_root, dataset_type, image_file, ours_dir)
        for image_file, ours_dir in zip(image_files, ours_dirs)
    ]
    output_image = os.path.abspath(output_image)

    last = len(rows) - 1
    panels = []
    for i, ((reference, methods), crops) in enumerate(zip(rows, crop_specs)):
        figure = MyCrops(
            reference_image=reference,
            method_images=methods,
            crops=_to_cropboxes(crops),
            method_names=list(_METHOD_NAMES),
            use_latex=True,
            # Method names are printed only under the last scene.
            add=(i == last),
        ).figure
        panels.extend([figure[0], figure[1]])

    # Make sure the destination directory exists before figuregen writes to it.
    os.makedirs(os.path.dirname(output_image), exist_ok=True)
    figuregen.figure(panels, width_cm=width_cm, filename=output_image)
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Build the three-scene crop comparison figure described by a YAML config."
    )
    parser.add_argument(
        "--config",
        type=str,
        default="./config/compare3_1.yaml",
        help="YAML file whose top-level keys are passed as keyword arguments to main()",
    )
    args = parser.parse_args()
    # Explicit UTF-8: the config may contain non-ASCII paths, and the platform's
    # default text encoding is not guaranteed to be UTF-8 (e.g. on Windows).
    with open(args.config, "r", encoding="utf-8") as f:
        config = yaml.safe_load(f)
    # safe_load returns None for an empty file (and a list/scalar for other
    # documents); main(**config) needs a mapping.
    if not isinstance(config, dict):
        parser.error(f"{args.config}: expected a mapping of main() keyword arguments")
    main(**config)