///////////////////////////////////////////////////////////////////////////////
//
//   jpeg をデコードする
//

#include <stdlib.h>
#include <string.h>

#include "jdecoder.h"


using namespace JPEG;

// Zigzag scan order: zigzag[k] is the natural (row-major, 8 * row + column)
// index inside an 8x8 block of the k-th coefficient in the coded sequence.
// The table is read-only, so it is declared const.
static const unsigned int zigzag[64] = {
  0,  1,  8, 16,  9,  2,  3, 10,
 17, 24, 32, 25, 18, 11,  4,  5,
 12, 19, 26, 33, 40, 48, 41, 34,
 27, 20, 13,  6,  7, 14, 21, 28,
 35, 42, 49, 56, 57, 50, 43, 36,
 29, 22, 15, 23, 30, 37, 44, 51,
 58, 59, 52, 45, 38, 31, 39, 46,
 53, 60, 61, 54, 47, 55, 62, 63
};


// Debug helper: writes the low `bits` bits of `code` to stdout as '0'/'1'
// characters, most significant bit first.  A width of 0 still prints one
// digit (the lowest bit), exactly like a width of 1.
static
void print_bits( unsigned int code, unsigned int bits ) {
  const unsigned int word_bits = sizeof(unsigned int) * 8;
  unsigned int pos = (bits > 1) ? bits - 1 : 0;

  for (;;) {
    // Positions beyond the width of `code` can only ever hold a zero.
    const unsigned int bit = (pos < word_bits) ? ((code >> pos) & 1) : 0;
    printf("%c", bit ? '1' : '0');

    if (pos == 0)
      break;
    --pos;
  }
}



/*! Fills a buffer with zero bytes.
 *
 * The previous hand-rolled version stored through an `unsigned long *`
 * obtained by casting the buffer pointer, which is undefined behaviour for
 * buffers that are not suitably aligned or that hold another type (this file
 * calls it on a `short` array).  memset has neither restriction.
 *
 * @param m     start of the buffer to clear
 * @param size  number of bytes to clear; 0 leaves the buffer untouched
 */
static inline void _bezero(void *m, unsigned int size) {
  if (size != 0)
    memset(m, 0, size);
}



//-----------------------------------------------------------------------------
/*! Copies a buffer.
 *
 * The previous hand-rolled version read and wrote through `unsigned long *`
 * casts of the two pointers, which is undefined behaviour for unaligned
 * buffers and violates the aliasing rules.  memcpy has neither restriction.
 *
 * @param dst    destination buffer, at least `dsize` bytes long
 * @param dsize  number of bytes to copy; 0 copies nothing
 * @param src    source buffer, at least `dsize` bytes long; must not overlap
 *               the destination
 */
static inline
void _copy(void *dst, unsigned int dsize, const void *src ) {
  if (dsize != 0)
    memcpy(dst, src, dsize);
}


// Returns the largest value in [begin, over), never less than 0:
// an empty range or a range of negative values yields 0.
static inline
int _get_max(int *begin, int *over) {
  int best = 0;
  for (int *cur = begin; cur < over; ++cur) {
    if (*cur > best)
      best = *cur;
  }
  return best;
}

// Returns the smallest value in [begin, over).
// An empty range yields 0x7fffffff.
static inline
int _get_min(int *begin, int *over) {
  int lowest = 0x7fffffff;
  for (int *cur = begin; cur < over; ++cur)
    lowest = (*cur < lowest) ? *cur : lowest;
  return lowest;
}



//=============================================================================
/*! Returns a human readable description of a JPEG marker.
 *
 * Only the low byte of `m` is looked at, so both 0xFFD8 and 0xD8 describe
 * the Start Of Image marker.
 *
 * @param m  marker code
 * @return   pointer to a static string, or 0 when the code has no
 *           description (currently only 0xFF)
 */
const char *JPEG::getMarkerDescription(unsigned int m) {

  m &= 0xff;

  if (0xf0 <= m && m <= 0xfd) return "JPG - Reserved (JPEG Extension)";

  // The whole range 0x02..0xBF is reserved.  The previous switch only
  // listed the two end points of that range.
  if (0x02 <= m && m <= 0xbf) return "Reserved";

  switch(m) {
  case 0x00: return "Just a 0xFF";
  case 0x01: return "TEM - For Temporary Use";
  case 0xc0: return "SOF0 - Start of Frame / Baseline DCT, Huffman";
  case 0xc1: return "SOF1 - Start of Frame / Extended sequential DCT, Huffman";
  case 0xc2: return "SOF2 - Start of Frame / Progressive DCT, Huffman";
  case 0xc3: return "SOF3 - Start of Frame / Spatial (Sequential) lossless, Huffman";
  case 0xc4: return "DHT - Define Huffman Table(s)";
  case 0xc5: return "SOF5 - Start of Frame / Differential sequential DCT, Huffman";
  case 0xc6: return "SOF6 - Start of Frame / Differential progressive DCT, Huffman";
  case 0xc7: return "SOF7 - Start of Frame / Differential lossless, Huffman";
  case 0xc8: return "JPG - Reserved for JPEG Extensions, Arithmetic";
  case 0xc9: return "SOF9 - Start of Frame / Extended sequential DCT, Arithmetic";
  case 0xca: return "SOF10 - Start of Frame / Progressive DCT, Arithmetic";
  case 0xcb: return "SOF11 - Start of Frame / Spatial (sequential) lossless, Arithmetic";
  case 0xcc: return "DAC - Define Arithmetic coding conditioning";
  case 0xcd: return "SOF13 - Start of Frame / Differential sequential DCT, Arithmetic";
  case 0xce: return "SOF14 - Start of Frame / Differential progressive DCT, Arithmetic";
  case 0xcf: return "SOF15 - Start of Frame / Differential spatial, Arithmetic";
  case 0xd0: return "RST0 - Restart Interval 0";
  case 0xd1: return "RST1 - Restart Interval 1";
  case 0xd2: return "RST2 - Restart Interval 2";
  case 0xd3: return "RST3 - Restart Interval 3";
  case 0xd4: return "RST4 - Restart Interval 4";
  case 0xd5: return "RST5 - Restart Interval 5";
  case 0xd6: return "RST6 - Restart Interval 6";
  case 0xd7: return "RST7 - Restart Interval 7";
  case 0xd8: return "SOI - Start of Image";
  case 0xd9: return "EOI - End of Image";
  case 0xda: return "SOS - Start of Scan";
  case 0xdb: return "DQT - Define Quantization Tables";
  case 0xdc: return "DNL - Define Number of Lines";
  case 0xdd: return "DRI - Define Restart Interval";
  case 0xde: return "DHP - Define Hierarchical progression";
  case 0xdf: return "EXP - Expand reference components";
  case 0xe0: return "APP0 - Application 0";
  case 0xe1: return "APP1 - Application 1";
  case 0xe2: return "APP2 - Application 2";
  case 0xe3: return "APP3 - Application 3";
  case 0xe4: return "APP4 - Application 4";
  case 0xe5: return "APP5 - Application 5";
  case 0xe6: return "APP6 - Application 6";
  case 0xe7: return "APP7 - Application 7";
  case 0xe8: return "APP8 - Application 8";
  case 0xe9: return "APP9 - Application 9";
  case 0xea: return "APP10 - Application 10";
  case 0xeb: return "APP11 - Application 11";
  case 0xec: return "APP12 - Application 12";
  case 0xed: return "APP13 - Application 13";
  case 0xee: return "APP14 - Application 14";
  case 0xef: return "APP15 - Application 15";
  case 0xfe: return "COM - Comment";
  }

  return 0;
}


//=============================================================================
/*! Constructor taking the InputStream to decode from.
 *
 * Puts the decoder into its "nothing parsed yet" state: no segment has been
 * seen, no buffers are allocated and the DC predictors are zero.
 */
Decoder::Decoder( InputStream *stream ) :
  stream_(stream)
{
  // Parsing state flags
  image_started_  = false;
  image_end_      = false;
  start_of_frame_ = false;
  skip_next_0x00_ = false;

  // Output side: no buffer yet, first line to hand out is line 0
  decoded_image_ = 0;
  getting_line_  = 0;

  // Per-component work areas and DC predictors
  for (int c = 0; c < 3; ++c) {
    coeff_workarea_[c] = 0;
    dc_pred_[c] = 0;
  }
}


//-----------------------------------------------------------------------------
/*! Fetches the next 16-bit marker.
 *
 * Advances the stream until the next 8 bits read 0xFF, then returns the
 * 16 bits starting there.
 *
 * NOTE(review): read1() is used as a single-bit read in read_bits_scan(),
 * so this scan presumably advances one bit (not one byte) at a time --
 * confirm that this is intended.
 *
 * @exception  UnexpectedEOFException  when the end of the file is reached
 */
unsigned long Decoder::get_marker() {
  for (;;) {
    if (stream_.show8() == 0xFF)
      break;
    stream_.read1();
  }

  return stream_.read16();
}



//-----------------------------------------------------------------------------
/*! Reads segments up to (and including) the Start Of Scan header.
 *
 * Does nothing once the scan header has been consumed, so it is safe to call
 * before every access to the image properties.  Markers without a dedicated
 * handler are logged to stdout and skipped, except for the 0xF0..0xFD range
 * which is rejected.
 *
 * Throws a C string when such a marker is met or when no Start Of Frame 0
 * segment has been seen by the time the scan starts.
 */
void Decoder::require_segments() {

  while ( ! image_started_ ) {
    unsigned int mk = get_marker();

    switch (mk) {
    case END_OF_IMAGE:
      image_end_ = true;
      break;

    case START_OF_IMAGE:
      segment_startOfImage();
      break;

    case APPLICATION_0:
      segment_application0();
      break;

    case DEFINE_QUANTIZATION_TABLE:
      segment_defineQuantizationTable();
      break;

    case START_OF_FRAME_0:
      segment_startOfFrame0();
      break;

    case DEFINE_HUFFMAN_TABLE:
      segment_defineHuffmanTable();
      break;

    case START_OF_SCAN:
      segment_startOfScan();
      image_started_ = true;   // the entropy coded data follows
      break;

    default: {
      // No handler: report the marker, then skip or reject it.
      mk &= 0xff;
      const char *description = getMarkerDescription(mk);
      if (!description) {
        printf("0x%02x .. unknown marker\n", mk);
      } else {
        printf("0x%02x .. %s\n", mk, description );
      }

      if ( mk >= 0xf0 && mk <= 0xfd ) {
        // Static storage: the message must outlive the throw.
        static char msg[128];
        sprintf(msg, "unignoreable but unknown marker 0x%2x", mk);
        throw msg;
      }

      segment_unknown();
      break;
    }
    }

  }

  if ( start_of_frame_ )
    return;

  throw "[start of frame 0] not found.";
}


//-----------------------------------------------------------------------------
//! Start Of Image handler.  Reads nothing from the stream.
void Decoder::segment_startOfImage() {
  return; // intentionally empty
}



//-----------------------------------------------------------------------------
/*! Application 0 segment.
 *
 * Reads the 5-byte identifier.  For "JFIF" the version, density and
 * thumbnail size fields are stored and the thumbnail data is skipped; for
 * "JFXX" the whole payload is skipped.
 *
 * The remaining length is tracked as a signed value.  It used to be
 * unsigned, so a segment shorter than its fixed fields wrapped around and
 * made the skip loops read until the end of the stream.
 *
 * Throws a C string when the identifier is neither "JFIF" nor "JFXX" or when
 * the segment is too short for its fixed fields.
 *
 * @exception UnexpectedEOFException
 */
void Decoder::segment_application0() {

  // 2 bytes .. segment length (the count includes these two bytes)
  int len = (int)stream_.read16() - 2;

  if (len < 5)
    throw "invalid application-0 segment length";

  // 5 bytes .. identifier
  for (int i=0; i<5; ++i)
    app0_identifier_[i] = (char)stream_.read8();
  len -= 5;

  // Is the identifier "JFIF\0" ?
  if ( !strcmp(app0_identifier_, "JFIF") ) {

    if (len < 9)
      throw "invalid application-0 segment length";

    // 1 byte .. major version number
    major_revision_ = stream_.read8();
    // 1 byte .. minor version number
    minor_revision_ = stream_.read8();

    // 1 byte .. density unit
    density_unit_ = stream_.read8();

    // 2 bytes .. horizontal density
    horiz_density_ = stream_.read16();
    // 2 bytes .. vertical density
    verti_density_ = stream_.read16();

    // 1 byte .. thumbnail width
    thumb_width_ = stream_.read8();
    // 1 byte .. thumbnail height
    thumb_height_ = stream_.read8();

    len -= 9;

    // The rest is the thumbnail image, which is not used
    while (len-- > 0)
      stream_.read8();

  } else if ( !strcmp(app0_identifier_, "JFXX") ) {
    // Extension segment: unsupported, skip its payload
    while (len-- > 0)
      stream_.read8();
  } else {
    throw "unsupported application-0 segment";
  }

  return;
}



//-----------------------------------------------------------------------------
//  Define Quantization Table  量子化テーブル定義
void Decoder::segment_defineQuantizationTable() {

  // 2 バイト .. セグメント長
  unsigned long len = stream_.read16();

  // テーブル(65バイト)が記述されている回数
  int count = (len - 2) / 65;

  while (count-- > 0) {
    unsigned long precision = stream_.read8(); // とりあえず 1 バイト取り出し
    unsigned long number = precision & 15;   //   下位 4 ビット .. 量子化テーブル識別子
    precision >>= 4;           //   上位 4 ビット .. 精度

    for (int i=0; i<64; ++i) {
      quant_tables_[number][i] = (short)stream_.read8();
    }
  }

  return;
}



//-----------------------------------------------------------------------------
/*! Start Of Frame 0 segment.
 *
 * Reads the frame header (sample precision, image size, and for every
 * component its id, sampling factors and quantization table selector), then
 * derives the MCU geometry and allocates the line buffer and the
 * per-component coefficient work areas.
 *
 * Throws a C string when the header describes something this decoder cannot
 * handle: a sample precision other than 8, a component count outside 1..3,
 * a sampling factor outside 1..4 or a quantization table selector above 3.
 */
void Decoder::segment_startOfFrame0 () {

  start_of_frame_ = true;

  // 2 bytes .. segment length (not needed, the layout follows from the
  // component count)
  stream_.read16();

  // 1 byte .. sample precision (fixed to 8)
  depth_per_component_ = stream_.read8();

  // 2 bytes .. height
  height_ = stream_.read16();

  // 2 bytes .. width
  width_ = stream_.read16();

  // 1 byte .. number of image components
  component_count_ = stream_.read8();

  if (depth_per_component_ != 8)
    throw "unsupported sample precision";

  // The per-component state of this class holds three entries.
  if (component_count_ < 1 || 3 < component_count_)
    throw "unsupported number of components";

  for (int i=0; i<component_count_; ++i) {
    // 1 byte .. component identifier
    component_id_[i] = stream_.read8();

    // 1 byte .. sampling factors
    unsigned int sampling_factor = stream_.read8();
    // high 4 bits .. horizontal sampling factor
    horizontal_sampling_factor_[i] = (int)(sampling_factor >> 4);
    // low 4 bits .. vertical sampling factor
    vertical_sampling_factor_[i] = (int)(sampling_factor & 15);

    // 1 byte .. quantization table selector
    qt_selector_[i] = stream_.read8();

    // A factor of 0 would give an empty MCU (and an endless decode loop).
    if (horizontal_sampling_factor_[i] < 1 || 4 < horizontal_sampling_factor_[i] ||
        vertical_sampling_factor_[i]   < 1 || 4 < vertical_sampling_factor_[i])
      throw "invalid sampling factor";

    if (3 < qt_selector_[i])
      throw "invalid quantization table selector";
  }


  // --- values derived once here for later convenience ---

  // bits per pixel
  depth_ = depth_per_component_ * component_count_;
  // bytes per pixel
  bytes_per_pixel_ = depth_ / 8;

  // the image is coded in units of 8 pixels
  coded_width_ = (width_ + 7) & ~7;
  coded_height_ = (height_ + 7) & ~7;

  // largest sampling factor over all components
  horizontal_sampling_factor_max_ =
            _get_max( horizontal_sampling_factor_, horizontal_sampling_factor_ + component_count_);
  vertical_sampling_factor_max_ =
            _get_max( vertical_sampling_factor_, vertical_sampling_factor_ + component_count_);

  // size of one MCU in pixels
  MCU_coded_width_ = horizontal_sampling_factor_max_ * 8;
  MCU_coded_height_ = vertical_sampling_factor_max_ * 8;

  // Bytes needed for one line: the width rounded up to a whole number of
  // MCUs.  Done by division because an MCU width of 24 (sampling factor 3)
  // is not a power of two, so a bit mask would round incorrectly.
  int w = ((width_ + MCU_coded_width_ - 1) / MCU_coded_width_) * MCU_coded_width_;
  bytes_per_line_ = bytes_per_pixel_ * w;

  // Buffer holding one decoded row of MCUs.
  // It is allocated with new[], so it has to be released with delete[].
  delete[] decoded_image_;
  decoded_image_ = new unsigned char[MCU_coded_height_ * bytes_per_line_];

  // Work areas used while decoding: 64 coefficients per block of the MCU
  for (int i=0; i<component_count_; ++i) {
    delete[] coeff_workarea_[i];
    coeff_workarea_[i] =
      new short[64 * horizontal_sampling_factor_[i] * vertical_sampling_factor_[i]];
  }

  return;
}



//-----------------------------------------------------------------------------
// Converts a Huffman code table into a Huffman tree.
//
// Entry i of the table describes symbol i: src_codes[i] is its code and
// src_bits[i] the code length in bits.  Entries with a length of 0 are
// unused symbols and are skipped.  For every other entry the code is walked
// from its most significant bit: all bits except the last select (and create
// on demand) a child through next_tree[], the last bit selects the value[]
// slot of the final node that receives the symbol.
//
// Always returns 0.

static inline
int make_huffman_tree( HuffmanTree *root,
                 int *src_codes, int *src_bits, int table_size )
{
  for (int symbol = 0; symbol < table_size; ++symbol) {
    const int code   = src_codes[symbol];
    const int length = src_bits[symbol];

    HuffmanTree *node = root;

    // Inner bits of the code: descend, growing the tree where needed
    for (int pos = length - 1; pos > 0; --pos) {
      const int branch = (code & (1 << pos)) ? 1 : 0;

      if (! node->next_tree[branch]) {
        // create a new branch
        node->next_tree[branch] = new HuffmanTree();
      }
      node = node->next_tree[branch];
    }

    // Last bit of the code: store the symbol in the leaf slot
    if (length > 0)
      node->value[code & 1] = symbol;
  }

  return 0; // no error
}

//-----------------------------------------------------------------------------
/*! Define Huffman Table segment.
 *
 * A segment may hold several tables.  Each one consists of a class/id byte,
 * 16 bytes giving the number of codes of length 1..16, and then the symbols
 * in order of increasing code length.  The codes are assigned sequentially
 * and the resulting table is turned into a tree with make_huffman_tree().
 *
 * Throws a C string when the class is above 1 or the id above 3 (these
 * values index huffman_trees_), or when the code counts do not fit into the
 * segment.
 *
 * NOTE(review): the tree is extended in place and never cleared here, so
 * redefining a table presumably keeps the branches of the previous
 * definition -- confirm whether HuffmanTree needs a reset.
 */
void Decoder::segment_defineHuffmanTable() {

  // 2 bytes .. segment length (the count includes these two bytes)
  int len = (int)stream_.read16() - 2;

  while (len > 0) {
    int i, k, lengths[16], code = 0, total = 0;
    // Indexed by symbol value; a bit length of 0 marks an unused symbol
    int codes[256] = {0}, bits[256] = {0};

    // 1 byte .. table class and identifier
    unsigned int class_and_id = stream_.read8();
    unsigned int table_class = class_and_id >> 4;
    unsigned int table_id    = class_and_id & 15;
    --len;

    if (1 < table_class || 3 < table_id)
      throw "invalid huffman table class or id";

    // 1 byte x 16 .. number of symbols for each code length
    for (i=0; i<16; ++i) {
      lengths[i] = (int)stream_.read8();
      total += lengths[i];
    }
    len -= 16;

    // There are only 256 symbols, and they all have to be in this segment
    if (256 < total || len < total)
      throw "invalid huffman table length";

    for (i=0; i<16; ++i) {

      // 1 byte x lengths[i] .. the symbols with a code of i+1 bits
      for (k=0; k<lengths[i]; ++k) {
        unsigned int value = stream_.read8();
        codes[value] = code++;
        bits[value] = i+1;
      }

      // Codes of the next length continue from here with one more bit
      code <<= 1;
      len -= lengths[i];
    }

    // Convert the table into a tree
    HuffmanTree  *tree  = &huffman_trees_[table_class][table_id];
    make_huffman_tree( tree, codes, bits, 256);
  }

  return;
}

//-----------------------------------------------------------------------------
void Decoder::segment_startOfScan() {

  // 2 バイト .. セグメント長
  int len = (int)stream_.read16() - 2;

  // 1 バイト .. スキャン内の成分数
  int count = (int)stream_.read8();

  for (int i=0; i<count; ++i) {
    // 1 バイト .. スキャン成分セレクタ
    int component_id = (int)stream_.read8();

    // とりあえず 1 バイト
    unsigned int dc_ac_table = stream_.read8();
      // 上位 4 ビット .. DC 成分用ハフマンテーブルセレクタ
      huffman_table_for_dc_[component_id - 1] = dc_ac_table >> 4;
      // 下位 4 ビット .. AC 成分用ハフマンテーブルセレクタ
      huffman_table_for_ac_[component_id - 1] = dc_ac_table & 15;
  }

  // 以下 3 バイトはここでは未使用
  unsigned int unused;
  unused = stream_.read8();
  unused = stream_.read8();
  unused = stream_.read8();

  return;
}



//-----------------------------------------------------------------------------
/*! Skips a segment this decoder has no handler for.
 *
 * Reads the 2-byte segment length and discards the payload that follows
 * (the length minus the two bytes of the length field itself).
 */
void Decoder::segment_unknown(){
  const int payload = (int)stream_.read16() - 2;

  for (int skipped = 0; skipped < payload; ++skipped)
    stream_.read8();
}


//=============================================================================
/*! Returns the image width.
 *
 * Parses the stream up to the start of the scan first (require_segments()),
 * so the value is the width field of the Start Of Frame segment.
 */
int Decoder::getWidth() {
  require_segments();
  return width_;
}

//=============================================================================
/*! Returns the image height.
 *
 * Parses the stream up to the start of the scan first (require_segments()),
 * so the value is the height field of the Start Of Frame segment.
 */
int Decoder::getHeight() {
  require_segments();
  return height_;
}


//=============================================================================
/*! Returns the pixel depth: the per-component depth multiplied by the
 *  number of components, as computed in segment_startOfFrame0().
 *
 * Parses the stream up to the start of the scan first (require_segments()).
 */
int Decoder::getDepth() {
  require_segments();
  return depth_;
}

//=============================================================================
/*! Returns the number of bytes in one decoded line.
 *
 * This is the size getLine() copies into the caller's buffer.  It is
 * computed in segment_startOfFrame0() from the width rounded up to the MCU
 * width, so it can be larger than getWidth() times the bytes per pixel.
 *
 * Parses the stream up to the start of the scan first (require_segments()).
 */
int Decoder::getBPL() {
  require_segments();
  return bytes_per_line_;
}



//=============================================================================
/*! Fetches the next line of the image.
 *
 * Lines are handed out top to bottom.  Whenever the requested line is the
 * first one of a row of MCUs, that whole row is decoded into
 * decoded_image_ first.
 *
 * The position inside the MCU row is computed with a real modulo.  The
 * previous bit mask was only correct when the MCU height is a power of two,
 * which is not the case for a vertical sampling factor of 3.
 *
 * Throws a C string when the MCU height is not positive.
 *
 * @param buf  buffer of at least getBPL() bytes that receives the line
 * @return     non-zero while more lines remain, 0 once the line just
 *             returned was the last one of the image
 */
int Decoder::getLine(void *buf) {
  require_segments();

  if (MCU_coded_height_ <= 0)
    throw "invalid MCU height";

  // line index inside the current row of MCUs
  int l = getting_line_ % MCU_coded_height_;

  if ( l == 0 )
    decode_MCU_line(); // decodes MCU_coded_height_ lines into decoded_image_

  _copy(buf, bytes_per_line_, decoded_image_ + l * bytes_per_line_ );

  ++getting_line_;

  return (getting_line_ < height_);
}



//-----------------------------------------------------------------------------
/*! Decodes one row of MCUs into decoded_image_.
 *
 * For every MCU across the image width, each component contributes
 * (horizontal factor x vertical factor) blocks.  The blocks are decoded in
 * the order they appear in the stream and stored one after another in the
 * component's work area; block_to_image() then converts the finished MCU.
 *
 * Must only be called when getting_line_ is at the start of an MCU row.
 * The check uses a real modulo so that MCU heights that are not a power of
 * two are handled too.
 *
 * Throws a C string when the MCU size is not positive (the loop below would
 * never terminate) or when called in the middle of an MCU row.
 */
void Decoder::decode_MCU_line() {

  if (MCU_coded_width_ <= 0 || MCU_coded_height_ <= 0)
    throw "invalid MCU size";

  if (getting_line_ % MCU_coded_height_)
    throw "assertion failed!";

  for (int x = 0;   x < width_;   x += MCU_coded_width_) {

    for (int cid=0; cid < component_count_; ++cid) {
      short *_coeff = coeff_workarea_[cid];

      // number of 8x8 blocks this component has in one MCU
      const int block_count =
        horizontal_sampling_factor_[cid] * vertical_sampling_factor_[cid];

      for (int b=0; b < block_count; ++b) {

        decode_block(_coeff, cid); // decode one 8x8 block

        dequantize(_coeff, cid); // inverse quantization

        idct(_coeff); // inverse discrete cosine transform

        _coeff += 64;
      }
    }

    // transfer the MCU now held in coeff_workarea_ to decoded_image_
    // (enlarging subsampled components where necessary)
    block_to_image(x);
  }

}



//-----------------------------------------------------------------------------
/*! Reads bits from the entropy coded scan data.
 *
 * Inside a scan the byte sequence 0xff 0x00 stands for a single 0xff byte.
 * Whenever the stream is on a byte boundary this function therefore
 *  - drops the 0x00 that has to follow a 0xff it saw earlier (throwing
 *    "unknown code" if the next byte is anything else), and
 *  - remembers whether the byte about to be read is 0xff.
 *
 * @param bits  number of bits to read
 * @return      the bits read, first bit in the most significant position
 */
m_ui32 Decoder::read_bits_scan(unsigned int bits) {
  m_ui32 value = 0;

  for (unsigned int remaining = bits; remaining > 0; --remaining) {
    if ( stream_.isAligned() ) {
      if ( skip_next_0x00_ ) {
        skip_next_0x00_ = false;
        if ( stream_.show8() != 0x00 )
          throw "unknown code";
        stream_.read8();
      }
      if ( stream_.show8() == 0xff )
        skip_next_0x00_ = true;
    }

    value = (value << 1) | stream_.read1();
  }

  return value;
}


//-----------------------------------------------------------------------------
/*! Decodes one Huffman symbol from the scan data.
 *
 * Walks the tree one bit at a time.  As long as the bit selects a child the
 * walk continues there; otherwise the value stored for that bit in the
 * current node is the decoded symbol.
 *
 * Throws "huffman code exception" when that value is
 * HuffmanTree::ERROR_VALUE.
 *
 * @param tree  root of the tree to decode with
 * @return      the decoded symbol
 */
m_ui32 Decoder::decode_huffman_tree( HuffmanTree *tree ) {
  HuffmanTree *node = tree;

  for (;;) {
    const m_ui32 bit = read_bits_scan(1);
    HuffmanTree *child = node->next_tree[bit];

    if (child) {
      node = child;
      continue;
    }

    if (node->value[bit] == HuffmanTree::ERROR_VALUE )
      throw "huffman code exception";

    return node->value[bit];
  }
}

//-----------------------------------------------------------------------------
/*! Decodes one 8x8 block of coefficients.
 *
 * The DC coefficient is coded as a difference to the previous block of the
 * same component (dc_pred_).  The AC coefficients are coded as
 * (zero run, size) symbols followed by `size` extra bits, in zigzag order;
 * they are stored at their natural position through zigzag[].  A symbol of
 * 0 ends the block, a symbol of 0xF0 stands for 16 zeros.
 *
 * Throws a C string when the DC size symbol is larger than 11 (the largest
 * size category for 8-bit samples; bigger values used to go straight into
 * the shifts below), on an AC symbol with size 0 that is neither of the two
 * special symbols, or when a run leads past the end of the block.
 *
 * @param coeff  receives the 64 coefficients
 * @param cid    index of the component the block belongs to
 */
void Decoder::decode_block(short *coeff, int cid) {
  m_ui32 bits;
  HuffmanTree *tree_dc = &huffman_trees_[0][huffman_table_for_dc_[cid]];
  HuffmanTree *tree_ac = &huffman_trees_[1][huffman_table_for_ac_[cid]];

  _bezero(coeff, sizeof(short) * 64);

  // Decode the DC coefficient
  bits = decode_huffman_tree( tree_dc );

  if (11 < bits)
    throw "invalid DC coefficient size";

  if (bits) {
    int dc_diff = (int)read_bits_scan(bits);
    // A leading 0 bit marks a negative value
    if ( (dc_diff & (1 << (bits - 1))) == 0 ) dc_diff -= (1 << bits) - 1;

    dc_pred_[cid] += dc_diff;
  }

  coeff[0] = dc_pred_[cid];


  // Decode the AC coefficients
  unsigned int p = 1;   // position in zigzag order
  for (;;) {
    bits = decode_huffman_tree(tree_ac);
    if (bits == 0)
      break;            // end of block: the remaining coefficients are zero

    // upper bits (everything above the low 4 bits) .. number of zeros
    unsigned int zerorun = bits >> 4;
    // low 4 bits .. bit length of the non-zero coefficient
    bits &= 0x0f;

    if (bits == 0) {
      if (zerorun != 15)
        throw "whats happen!?";

      p += 16;
      if (63 < p)
        throw "out of bounds.";

    } else {
      p += zerorun;
      if (63 < p)
        throw "out of bounds.";

      int ac = (int)read_bits_scan( bits );
      // A leading 0 bit marks a negative value
      if ( (ac & (1 << (bits - 1))) == 0 ) ac -= (1 << bits) - 1;

      coeff[zigzag[p++]] = (short)ac;
    }

    // The block is complete once the position has passed 63
    if (63 < p)
      break;
  }
}


//-----------------------------------------------------------------------------
//! 逆量子化
void Decoder::dequantize( short *coeff, int cid ) {
  int i=64;
  short *quant_table = quant_tables_[ qt_selector_[cid] ];
  while (i--)
    coeff[i] *= quant_table[i];
}


//-----------------------------------------------------------------------------
/*! Transfers the MCU held in coeff_workarea_ to decoded_image_.
 *
 * Every pixel of the MCU picks its sample from each component's work area.
 * Components with a smaller sampling factor than the maximum are enlarged by
 * repeating samples: the component position (cx, cy) only advances once the
 * accumulated ratio reaches the maximum factor.
 *
 * With three components the samples are taken as Y, Cb, Cr and converted;
 * the three output bytes are written in the order blue, green, red.  With a
 * single component only the luminance byte is written.  The previous code
 * always read components 1 and 2 and always wrote three bytes, which
 * dereferences unallocated work areas and overruns the line buffer for
 * grayscale images.
 *
 * Throws a C string when the component count is neither 1 nor 3.
 *
 * @param xpoint  x coordinate in the image of the MCU's left edge
 */
void Decoder::block_to_image( int xpoint ) {

  if (component_count_ != 1 && component_count_ != 3)
    throw "unsupported number of components";

  const bool has_chroma = (component_count_ == 3);

  int cid, ix, iy;
  // Current sample position and accumulated sampling ratio per component
  int cx[3] = {0}, cy[3] = {0}, x_ratio[3] = {0}, y_ratio[3] = {0};

// Sample of component _n at (cx, cy): the work area holds 8x8 blocks one
// after another, row by row of blocks.
#define V(_n)                             \
  *(                                      \
    (coeff_workarea_[_n]) +               \
    ( 64 * horizontal_sampling_factor_[_n] * (cy[_n] / 8) ) +  \
    ( 64 * (cx[_n] / 8) ) +               \
    (8 * (cy[_n] & 7) + (cx[_n] & 7))     \
  )

// Clamp to the range of an 8-bit sample
#define CLIP(_v) ((unsigned char)(((_v) < 0) ? 0 : (255 < (_v) ? 255 : (_v))))

  for (iy=0; iy< MCU_coded_height_; ++iy) {

    // Every line starts again at the left edge of each component
    for(cid=0; cid<component_count_; ++cid) {
      cx[cid] = 0;
      x_ratio[cid] = 0;
    }

    for (ix=0; ix< MCU_coded_width_; ++ix) {
      unsigned char *pix = decoded_image_ +
                         (bytes_per_line_ * iy) +
                         (bytes_per_pixel_ * (ix + xpoint));

      // Undo the level shift of the luminance
      int Y  = (int)V(0) + 128;

      if (has_chroma) {
        int Cb = (int)V(1);
        int Cr = (int)V(2);

        // Fixed point YCbCr to RGB conversion (factors scaled by 10000)
        pix[0] = CLIP(Y +  17720 * Cb / 10000);
        pix[1] = CLIP(Y - ( 3441 * Cb + 7141 * Cr) / 10000);
        pix[2] = CLIP(Y +  14020 * Cr / 10000 );
      } else {
        pix[0] = CLIP(Y);
      }

      // Advance each component horizontally according to its ratio
      for(cid=0; cid<component_count_; ++cid) {
        x_ratio[cid] += horizontal_sampling_factor_[cid];
        if (horizontal_sampling_factor_max_ <= x_ratio[cid]) {
          x_ratio[cid] -= horizontal_sampling_factor_max_;
          ++cx[cid];
        }
      }

    }

    // Advance each component vertically according to its ratio
    for(cid=0; cid<component_count_; ++cid) {
      y_ratio[cid] += vertical_sampling_factor_[cid];
      if (vertical_sampling_factor_max_ <= y_ratio[cid]) {
        y_ratio[cid] -= vertical_sampling_factor_max_;
        ++cy[cid];
      }
    }

  }

#undef V
#undef CLIP
}



