#include "tinyexr.h"

#define STB_IMAGE_IMPLEMENTATION
#include "stb_image.h"

#define STB_IMAGE_WRITE_IMPLEMENTATION
#include "stb_image_write.h"

#include <algorithm>
#include <array>
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <iostream>
#include <string>
#include <vector>

// From Filament.
//
// Decodes one RGBM pixel into linear RGB.
//   rgbm   : input,  4 floats (R, G, B, M).
//   linear : output, 3 floats.
// Every colour channel is multiplied by M * 16 and the product is then
// squared (gamma -> linear space).
static inline void RGBMtoLinear(const float rgbm[4], float linear[3]) {
  for (int c = 0; c < 3; ++c) {
    const float scaled = rgbm[c] * rgbm[3] * 16.0f;
    linear[c] = scaled * scaled;
  }
}

// Encodes a linear RGB pixel as RGBM (inverse of RGBMtoLinear).
//   linear : input,  3 floats. Negative components are treated as 0.
//   rgbm   : output, 4 floats (R, G, B, M), each in [0, 1].
// The decoder computes (rgb * M * 16)^2, so the encoder must take the square
// root first, divide by the range (16), and then factor out the shared
// multiplier M.
static inline void LinearToRGBM(const float linear[3], float rgbm[4]) {
  // Linear to gamma space (sqrt undoes the squaring done when decoding), then
  // map the [0..16] range to [0..1].
  for (int c = 0; c < 3; ++c) {
    rgbm[c] = std::sqrt(std::max(linear[c], 0.0f)) / 16.0f;
  }

  const float maxComponent =
      std::max(std::max(rgbm[0], rgbm[1]), std::max(rgbm[2], 1e-6f));

  // Don't let M go below 1 in the [0..16] range
  float m = std::max(1.0f / 16.0f, std::min(maxComponent, 1.0f));
  // Round M up to the next 8-bit step so rgb / M never exceeds 1 because of
  // quantization.
  m = std::ceil(m * 255.0f) / 255.0f;
  rgbm[3] = m;

  // saturate([0.0, 1.0])
  for (int c = 0; c < 3; ++c) {
    rgbm[c] = std::max(0.0f, std::min(1.0f, rgbm[c] / m));
  }
}

// Returns the text after the last '.' of the final path component of
// `filename`, without the dot. Returns an empty string when the final
// component has no '.'.
//
// Only a dot located after the last path separator ('/' or '\\') counts, so a
// dotted directory name such as "dir.v2/file" does not yield an extension.
static std::string GetFileExtension(const std::string& filename) {
  const std::string::size_type dot = filename.find_last_of('.');
  if (dot == std::string::npos) {
    return "";
  }

  const std::string::size_type sep = filename.find_last_of("/\\");
  if (sep != std::string::npos && sep > dot) {
    // The last dot belongs to a directory name, not to the file name.
    return "";
  }

  return filename.substr(dot + 1);
}

// Floating point image.
// The code in this file fills `data` with interleaved RGB triples
// (3 floats per pixel), stored row by row.
struct Image {
  int width = 0;            // Width in pixels.
  int height = 0;           // Height in pixels.
  std::vector<float> data;  // Pixel values.
};

// Loads the six cubemap faces named in `face_filenames` into `output`.
//
// The loader is selected by file extension:
//   "exr"  / "EXR"  : loaded with LoadEXR(), alpha channel dropped.
//   "rgbm" / "RGBM" : loaded with stbi_load(); must have 4 channels, which are
//                     decoded with RGBMtoLinear().
// Every face is stored as interleaved linear RGB floats (3 per pixel).
//
// Returns true when all six faces were loaded. On the first failure an error
// is printed to std::cerr and false is returned; faces loaded before the
// failure remain in `output`.
static bool LoadCubemaps(const std::array<std::string, 6> face_filenames,
                         std::array<Image, 6>* output) {
  for (size_t i = 0; i < 6; i++) {
    const std::string& filename = face_filenames[i];
    const std::string ext = GetFileExtension(filename);

    Image image;

    if ((ext.compare("exr") == 0) || (ext.compare("EXR") == 0)) {
      int width = 0;
      int height = 0;
      float* rgba = nullptr;
      // Must start as null: it is tested below and the loader is not
      // guaranteed to assign it on every failure path.
      const char* err = nullptr;

      const int ret = LoadEXR(&rgba, &width, &height, filename.c_str(), &err);
      if (ret != 0) {
        if (err) {
          std::cerr << "EXR load error: " << err << std::endl;
        } else {
          std::cerr << "EXR load error: code " << ret << std::endl;
        }
        return false;
      }

      // Multiply in size_t so large images cannot overflow `int`.
      const size_t num_pixels = size_t(width) * size_t(height);

      image.width = width;
      image.height = height;
      image.data.resize(num_pixels * 3);

      // RGBA -> RGB
      for (size_t p = 0; p < num_pixels; p++) {
        image.data[3 * p + 0] = rgba[4 * p + 0];
        image.data[3 * p + 1] = rgba[4 * p + 1];
        image.data[3 * p + 2] = rgba[4 * p + 2];
      }

      free(rgba);

    } else if ((ext.compare("rgbm") == 0) || (ext.compare("RGBM") == 0)) {
      int width = 0;
      int height = 0;
      int n = 0;

      unsigned char* data =
          stbi_load(filename.c_str(), &width, &height, &n, STBI_default);

      if (!data) {
        std::cerr << "Failed to load file: " << filename << std::endl;
        return false;
      }

      if (n != 4) {
        std::cerr << "Not a RGBM encoded image: " << filename << std::endl;
        // stb_image allocates with malloc() unless STBI_MALLOC is overridden,
        // which this file does not do, so free() releases the pixel buffer.
        free(data);
        return false;
      }

      const size_t num_pixels = size_t(width) * size_t(height);

      image.width = width;
      image.height = height;
      // Three floats (linear RGB) are written per pixel below.
      image.data.resize(num_pixels * 3);

      for (size_t p = 0; p < num_pixels; p++) {
        float rgbm[4];
        // [0, 1.0]
        rgbm[0] = data[4 * p + 0] / 255.0f;
        rgbm[1] = data[4 * p + 1] / 255.0f;
        rgbm[2] = data[4 * p + 2] / 255.0f;
        rgbm[3] = data[4 * p + 3] / 255.0f;

        float linear[3];
        RGBMtoLinear(rgbm, linear);

        image.data[3 * p + 0] = linear[0];
        image.data[3 * p + 1] = linear[1];
        image.data[3 * p + 2] = linear[2];
      }

      free(data);

    } else {
      std::cerr << "Unknown file extension : " << ext << std::endl;
      return false;
    }

    (*output)[i] = std::move(image);
    std::cout << "Loaded " << filename << std::endl;
  }

  return true;
}

// Maps a direction (x, y, z) to a cubemap face and a position on that face.
//
//   index : receives the face, 0..5 = +X, -X, +Y, -Y, +Z, -Z.
//   u, v  : receive the face coordinates in [0, 1].
//
// The face is the axis with the largest magnitude. When magnitudes tie, Z is
// preferred over Y and Y over X. A component that is not strictly positive
// selects the negative face of its axis.
//
// Degenerate input (zero vector, NaN or infinite dominant component) cannot be
// projected; in that case u = v = 0.5 is returned and `index` is still a valid
// face (0 if no axis could be selected at all).
void convert_xyz_to_cube_uv(float x, float y, float z, int* index, float* u,
                            float* v) {
  const float absX = std::fabs(x);
  const float absY = std::fabs(y);
  const float absZ = std::fabs(z);

  const bool xMajor = (absX >= absY) && (absX >= absZ);
  const bool yMajor = (absY >= absX) && (absY >= absZ);
  const bool zMajor = (absZ >= absX) && (absZ >= absY);

  // Defaults used when no axis qualifies (any comparison involving NaN fails).
  float maxAxis = 0.0f;
  float uc = 0.0f;
  float vc = 0.0f;
  *index = 0;

  if (zMajor) {
    maxAxis = absZ;
    vc = y;  // v (0 to 1) goes from -y to +y
    if (z > 0.0f) {
      // POSITIVE Z: u (0 to 1) goes from -x to +x
      uc = x;
      *index = 4;
    } else {
      // NEGATIVE Z: u (0 to 1) goes from +x to -x
      uc = -x;
      *index = 5;
    }
  } else if (yMajor) {
    maxAxis = absY;
    uc = x;  // u (0 to 1) goes from -x to +x
    if (y > 0.0f) {
      // POSITIVE Y: v (0 to 1) goes from +z to -z
      vc = -z;
      *index = 2;
    } else {
      // NEGATIVE Y: v (0 to 1) goes from -z to +z
      vc = z;
      *index = 3;
    }
  } else if (xMajor) {
    maxAxis = absX;
    vc = y;  // v (0 to 1) goes from -y to +y
    if (x > 0.0f) {
      // POSITIVE X: u (0 to 1) goes from +z to -z
      uc = -z;
      *index = 0;
    } else {
      // NEGATIVE X: u (0 to 1) goes from -z to +z
      uc = z;
      *index = 1;
    }
  }

  // Avoid 0/0 and inf/inf: fall back to the face centre.
  if (!(maxAxis > 0.0f) || !std::isfinite(maxAxis)) {
    *u = 0.5f;
    *v = 0.5f;
    return;
  }

  // Convert range from -1 to 1 to 0 to 1
  *u = 0.5f * (uc / maxAxis + 1.0f);
  *v = 0.5f * (vc / maxAxis + 1.0f);
}

//
// Bilinear texture lookup with repeat wrapping.
//
//   rgba     : output, receives `channels` floats.
//   u, v     : texture coordinates; only their fractional part is used.
//   width    : texture width in texels.
//   height   : texture height in texels.
//   channels : number of interleaved floats per texel.
//   texels   : texel data, row by row.
//
static void SampleTexture(float* rgba, float u, float v, int width, int height,
                          int channels, const float* texels) {
  // Wrap mode = repeat: keep the fractional part, then clamp it to [0, 1].
  const float frac_u = std::min(std::max(u - std::floor(u), 0.0f), 1.0f);
  const float frac_v = std::min(std::max(v - std::floor(v), 0.0f), 1.0f);

  const int last_col = width - 1;
  const int last_row = height - 1;

  // Continuous texel position.
  const float pos_x = last_col * frac_u;
  const float pos_y = last_row * frac_v;

  // The two columns / rows surrounding the position, kept inside the texture.
  const int col0 = std::max(0, std::min(static_cast<int>(pos_x), last_col));
  const int row0 = std::max(0, std::min(static_cast<int>(pos_y), last_row));
  const int col1 = std::max(0, std::min(col0 + 1, last_col));
  const int row1 = std::max(0, std::min(row0 + 1, last_row));

  const float frac_x = pos_x - static_cast<float>(col0);
  const float frac_y = pos_y - static_cast<float>(row0);

  // Corner weights. The (1.0 - frac_y) term is evaluated in double precision
  // and the product is narrowed to float afterwards.
  const double inv_frac_y = 1.0 - frac_y;
  const float w_c0r0 = static_cast<float>((1.0f - frac_x) * inv_frac_y);
  const float w_c0r1 = (1.0f - frac_x) * frac_y;
  const float w_c1r0 = static_cast<float>(frac_x * inv_frac_y);
  const float w_c1r1 = frac_x * frac_y;

  // First channel of each of the four corner texels.
  const float* t_c0r0 = texels + channels * (row0 * width + col0);
  const float* t_c1r0 = texels + channels * (row0 * width + col1);
  const float* t_c0r1 = texels + channels * (row1 * width + col0);
  const float* t_c1r1 = texels + channels * (row1 * width + col1);

  for (int c = 0; c < channels; ++c) {
    rgba[c] = w_c0r0 * t_c0r0[c] + w_c0r1 * t_c0r1[c] +
              w_c1r0 * t_c1r0[c] + w_c1r1 * t_c1r1[c];
  }
}

static void SampleCubemap(const std::array<Image, 6>& cubemap_faces,
                          const float n[3], float col[3]) {
  int face;
  float u, v;
  convert_xyz_to_cube_uv(n[0], n[1], n[2], &face, &u, &v);

  v = 1.0f - v;

  // std::cout << "face = " << face << std::endl;

  // TODO(syoyo): Do we better consider seams on the cubemap face border?
  const Image& tex = cubemap_faces[face];

  // std::cout << "n = " << n[0] << ", " << n[1] << ", " << n[2] << ", uv = " <<
  // u << ", " << v << std::endl;

  SampleTexture(col, u, v, tex.width, tex.height, /* RGB */ 3, tex.data.data());

// col[0] = u;
// col[1] = v;
// col[2] = 0.0f;
#if 0
  if (face == 0) {
    col[0] = 1.0f; 
    col[1] = 0.0f; 
    col[2] = 0.0f; 
  } else if (face == 1) {
    col[0] = 0.0f; 
    col[1] = 1.0f; 
    col[2] = 0.0f; 
  } else if (face == 2) {
    col[0] = 0.0f; 
    col[1] = 0.0f; 
    col[2] = 1.0f; 
  } else if (face == 3) {
    col[0] = 1.0f; 
    col[1] = 0.0f; 
    col[2] = 1.0f; 
  } else if (face == 4) {
    col[0] = 0.0f; 
    col[1] = 1.0f; 
    col[2] = 1.0f; 
  } else if (face == 5) {
    col[0] = 1.0f; 
    col[1] = 1.0f; 
    col[2] = 1.0f; 
  }
#endif
}

// Resamples a cubemap into a longitude/latitude (equirectangular) image.
//
//   cubemap_faces : the six source faces.
//   phi_offset    : rotation added to the longitude, in degrees.
//   width         : output width in pixels; the output height is width / 2.
//                   Values below 0 are treated as 0 (empty output).
//   longlat       : output image; width, height and data are overwritten
//                   with interleaved RGB floats.
//
// Each output pixel is sampled once, at its centre, with SampleCubemap().
static void CubemapToLonglat(const std::array<Image, 6>& cubemap_faces,
                             const float phi_offset, /* in angle */
                             const int width, Image* longlat) {
  // A negative width would turn into a huge unsigned loop bound below.
  const int out_width = std::max(width, 0);
  const int out_height = out_width / 2;

  // Size and index arithmetic is done in size_t so it cannot overflow `int`.
  const size_t w = size_t(out_width);
  const size_t h = size_t(out_height);

  longlat->width = out_width;
  longlat->height = out_height;
  longlat->data.resize(w * h * 3);  // RGB

  const float kPI = 3.141592f;
  const float phi_offset_rad = (phi_offset) * kPI / 180.0f;

  for (size_t y = 0; y < h; y++) {
    const float theta = ((y + 0.5f) / float(out_height)) * kPI;  // [0, pi]
    // Constant along a row, so computed once per row.
    const float sin_theta = std::sin(theta);
    const float cos_theta = std::cos(theta);

    for (size_t x = 0; x < w; x++) {
      float phi = ((x + 0.5f) / float(out_width)) * 2.0f * kPI;  // [0, 2 pi]
      phi += phi_offset_rad;

      float n[3];

      // Y-up
      n[0] = sin_theta * std::cos(phi);
      n[1] = cos_theta;
      n[2] = -sin_theta * std::sin(phi);

      float col[3];
      SampleCubemap(cubemap_faces, n, col);

      const size_t dst = 3 * (y * w + x);
      longlat->data[dst + 0] = col[0];
      longlat->data[dst + 1] = col[1];
      longlat->data[dst + 2] = col[2];
    }
  }
}

// Converts a float in [0, 1] to an 8-bit value in [0, 255], rounding to the
// nearest integer.
//
// Input outside [0, 1] is clamped and NaN maps to 0. The range checks are done
// on the float so that no out-of-range float-to-integer conversion (which is
// undefined behaviour) can happen.
static unsigned char ftouc(const float f) {
  if (!(f > 0.0f)) {
    return 0;  // zero, negative values and NaN
  }
  if (f >= 1.0f) {
    return 255;
  }
  // f is in (0, 1), so the sum is in (0.5, 255.5) and truncates to 0..255.
  return static_cast<unsigned char>(f * 255.0f + 0.5f);
}

// Entry point.
//
// Arguments:
//   argv[1..6] : cubemap faces in the order +X, -X, +Y, -Y, +Z, -Z.
//   argv[7]    : output width in pixels (the output height is width / 2).
//   argv[8]    : output file name. The extension selects the format:
//                "exr"/"EXR" -> EXR, "rgbm"/"RGBM" -> RGBM stored as PNG,
//                anything else -> Radiance HDR.
//   argv[9]    : optional longitude offset in degrees (default 0).
//
// Returns 0 on success and EXIT_FAILURE when loading, converting or saving
// fails. Exits with -1 when too few arguments are given.
int main(int argc, char** argv) {
  float phi_offset = 0.0f;

  if (argc < 9) {
    printf(
        "Usage: cube2longlat px.exr nx.exr py.exr ny.exr pz.exr nz.exr "
        "output_width output.exr [phi_offset]\n");
    exit(-1);
  }

  std::array<std::string, 6> face_filenames;
  for (size_t i = 0; i < face_filenames.size(); i++) {
    face_filenames[i] = argv[1 + i];
  }

  const int output_width = atoi(argv[7]);
  if (output_width < 2) {
    // atoi() yields 0 for non-numeric text; widths below 2 give an image with
    // zero rows because the height is width / 2.
    std::cerr << "Invalid output_width (must be >= 2) : " << argv[7]
              << std::endl;
    return EXIT_FAILURE;
  }

  const std::string output_filename = argv[8];

  if (argc > 9) {
    phi_offset = static_cast<float>(atof(argv[9]));
  }

  std::array<Image, 6> cubemaps;

  if (!LoadCubemaps(face_filenames, &cubemaps)) {
    std::cerr << "Failed to load cubemap faces." << std::endl;
    return EXIT_FAILURE;
  }

  Image longlat;

  CubemapToLonglat(cubemaps, phi_offset, output_width, &longlat);

  {
    const std::string ext = GetFileExtension(output_filename);
    if ((ext.compare("exr") == 0) || (ext.compare("EXR") == 0)) {
      // Must start as null: it is tested below and the writer is not
      // guaranteed to assign it on every failure path.
      const char* err = nullptr;
      const int ret =
          SaveEXR(longlat.data.data(), longlat.width, longlat.height,
                  /* RGB */ 3, /* fp16 */ 0, output_filename.c_str(), &err);
      if (ret != TINYEXR_SUCCESS) {
        if (err) {
          std::cout << "Failed to save image as EXR. msg = " << err
                    << ", code = " << ret << std::endl;
          FreeEXRErrorMessage(err);
        } else {
          std::cout << "Failed to save image as EXR. code = " << ret
                    << std::endl;
        }
        return EXIT_FAILURE;
      }
    } else if ((ext.compare("rgbm") == 0) || (ext.compare("RGBM") == 0)) {
      const size_t num_pixels =
          size_t(longlat.width) * size_t(longlat.height);

      // Four bytes (R, G, B, M) per pixel; sized up front because the loop
      // below writes by index.
      std::vector<unsigned char> rgbm_image(num_pixels * 4);

      for (size_t j = 0; j < num_pixels; j++) {
        float linear[3];
        linear[0] = longlat.data[3 * j + 0];
        linear[1] = longlat.data[3 * j + 1];
        linear[2] = longlat.data[3 * j + 2];

        float rgbm[4];

        LinearToRGBM(linear, rgbm);

        rgbm_image[4 * j + 0] = ftouc(rgbm[0]);
        rgbm_image[4 * j + 1] = ftouc(rgbm[1]);
        rgbm_image[4 * j + 2] = ftouc(rgbm[2]);
        rgbm_image[4 * j + 3] = ftouc(rgbm[3]);  // M, the shared multiplier
      }

      // Save as PNG.
      const int ret =
          stbi_write_png(output_filename.c_str(), longlat.width, longlat.height,
                         4, rgbm_image.data(), longlat.width * 4);

      if (ret == 0) {
        std::cerr << "Failed to save image as RGBM file : " << output_filename
                  << std::endl;
        return EXIT_FAILURE;
      }

    } else {
      if ((ext.compare("hdr") == 0) || (ext.compare("HDR") == 0)) {
        // ok
      } else {
        std::cout << "Unknown file extension. Interpret it as RGBE format : "
                  << ext << std::endl;
      }

      const int ret = stbi_write_hdr(output_filename.c_str(), longlat.width,
                                     longlat.height, 3, longlat.data.data());

      if (ret == 0) {
        std::cerr << "Failed to save image as HDR file : " << output_filename
                  << std::endl;
        return EXIT_FAILURE;
      }
    }
  }

  std::cout << "Write " << output_filename << std::endl;

  return 0;
}