///////////////////////////////////////////////////////////////////////////////
//
//   jpeg をデコードする
//

#include <stdlib.h>
#include <string.h>

#include "jdecoder.h"


static unsigned int zigzag[] = {
  0,  1,  8, 16,  9,  2,  3, 10,
 17, 24, 32, 25, 18, 11,  4,  5,
 12, 19, 26, 33, 40, 48, 41, 34,
 27, 20, 13,  6,  7, 14, 21, 28,
 35, 42, 49, 56, 57, 50, 43, 36,
 29, 22, 15, 23, 30, 37, 44, 51,
 58, 59, 52, 45, 38, 31, 39, 46,
 53, 60, 61, 54, 47, 55, 62, 63
};



//=============================================================================
/*! コンストラクタ  InputStream 型を引数に
 */
JPEGDecoder::JPEGDecoder( InputStream *stream ) :
  stream_(stream)
{
  image_started_ = false;
  image_end_ = false;
  start_of_frame_ = false;

  decoded_image_ = 0;
  coeff_workarea_[0] =
  coeff_workarea_[1] =
  coeff_workarea_[2] = 0;

  getting_line_ = 0;
  dc_pred_[0] =
  dc_pred_[1] =
  dc_pred_[2] = 0;
  skip_next_0x00_ = false;
}


//-----------------------------------------------------------------------------
/*! 16 ビットのマーカを得る
 * @exception  UnexpectedEOFException  ファイルが終端に達したとき
 */
unsigned long JPEGDecoder::get_marker() {
  while (stream_.show8() != 0xFF) {
    stream_.read1();
  }

  return stream_.read16();
}



//-----------------------------------------------------------------------------
/*!  スキャンの直前まで読み込み
 */
void JPEGDecoder::require_segments() {

  while ( ! image_started_ ) {
    unsigned int mk;
    switch( mk = get_marker() ) {
    case END_OF_IMAGE:
      image_end_= true;
      break;

    case START_OF_IMAGE :
      segment_startOfImage();
      break;

    case APPLICATION_0:
      segment_application0();
      break;

    case DEFINE_QUANTIZATION_TABLE:
      segment_defineQuantizationTable();
      break;

    case START_OF_FRAME_0:
      segment_startOfFrame0();
      break;

    case DEFINE_HUFFMAN_TABLE:
      segment_defineHuffmanTable();
      break;

    case START_OF_SCAN:
      segment_startOfScan();
      image_started_ = true;
      break;

    default:
      printf("unknown marker %2x\n", mk );
      segment_unknown();
      break;
    }

  }

  if (! start_of_frame_ )
    throw "[start of frame 0] not found.";

  return;
}


//-----------------------------------------------------------------------------
//! Start Of Image
void JPEGDecoder::segment_startOfImage() {
  return; // 中身は空
}



//-----------------------------------------------------------------------------
/*! Application 0
 * @exception UnexpectedEOFException
 */
void JPEGDecoder::segment_application0() {

  // 2 バイト .. セグメント長
  unsigned long len = stream_.read16();

  // 5 バイト .. 識別子
  for (int i=0; i<5; ++i)
    app0_identifier_[i] = (char)stream_.read8();

  // 識別子は "JFIF\0" ?
  if ( !strcmp(app0_identifier_, "JFIF") ) {

    // 1 バイト .. メジャーバージョン番号
    major_revision_ = stream_.read8();
    // 1 バイト .. マイナーバージョン番号
    minor_revision_ = stream_.read8();

    // 1 バイト .. 密度単位
    density_unit_ = stream_.read8();

    // 2 バイト .. 横密度
    horiz_density_ = stream_.read16();
    // 2 バイト .. 縦密度
    verti_density_ = stream_.read16();

    // 1 バイト .. サムネイル横幅
    thumb_width_ = stream_.read8();
    // 1 バイト .. サムネイル高さ
    thumb_height_ = stream_.read8();

    // 残り (len - 16) バイトはサムネイル画像
    len -= 16;
    while (len > 0) {
      len--;
      stream_.read8();
    }

  } else if ( !strcmp(app0_identifier_, "JFXX") ) {
    // unsupported
  }

  return;
}



//-----------------------------------------------------------------------------
//  Define Quantization Table  量子化テーブル定義
void JPEGDecoder::segment_defineQuantizationTable() {

  // 2 バイト .. セグメント長
  unsigned long len = stream_.read16();

  // テーブル(65バイト)が記述されている回数
  int count = (len - 2) / 65;

  while (count-- > 0) {
    unsigned long precision = stream_.read8(); // とりあえず 1 バイト取り出し
    unsigned long number = precision & 15;   //   下位 4 ビット .. 量子化テーブル識別子
    precision >>= 4;           //   上位 4 ビット .. 精度

    for (int i=0; i<64; ++i) {
      quant_tables_[number][i] = (short)stream_.read8();
    }
  }

  return;
}



static inline
int _get_max(int *begin, int *over) {
  int max = 0;
  while (begin < over) {
    if ( max < *begin ) max = *begin;
    ++begin;
  }
  return max;
}

static inline
int _get_min(int *begin, int *over) {
  int min = 0x7fffffff;
  while (begin < over) {
    if (*begin < min) min = *begin;
    ++begin;
  }
  return min;
}

//-----------------------------------------------------------------------------
//  Start Of Frame 0

void JPEGDecoder::segment_startOfFrame0 () {

  start_of_frame_ = true;

  // 2 バイト .. セグメント長
  unsigned int len = stream_.read16();

  // 1 バイト .. 画素深度 ( '8' 固定 )
  depth_per_component_ = stream_.read8();

  // 2 バイト .. 高さ
  height_ = stream_.read16();

  // 2 バイト .. 横幅
  width_ = stream_.read16();

  // 1 バイト .. 画像成分数
  component_count_ = stream_.read8();

  for (int i=0; i<component_count_; ++i) {
    // 1 バイト .. 成分識別子
    component_id_[i] = stream_.read8();

    // 1 バイト .. サンプリングファクタ
    unsigned int sampling_factor = stream_.read8();
      // そのうち上位 4 ビット .. 水平サンプリングファクタ
      horizontal_sampling_factor_[i] = (int)sampling_factor >> 4;
      // 下位 4 ビット .. 垂直サンプリングファクタ
      vertical_sampling_factor_[i] = (int)sampling_factor & 15;

printf("horiz sampling factor %d: %d\n", i, horizontal_sampling_factor_[i]);
printf("verti sampling factor %d: %d\n", i, vertical_sampling_factor_[i]);

    // 1 バイト.. 量子化テーブルセレクタ
    qt_selector_[i] = stream_.read8();
  }


  // --- 後の利便性のために，あらかじめ色々計算しておくとよい ---

  // 1 ドットの画素深度
  depth_ = depth_per_component_ * component_count_;
  // 1 ドットのバイト数
  bytes_per_pixel_ = depth_ / 8;

  // 実際には画像は 8 ドット単位に揃えて符号化されている
  coded_width_ = (width_ + 7) & ~7;
  coded_height_ = (height_ + 7) & ~7;

  // いくつか取得できたサンプリングファクタの中で，最大値を保持
  horizontal_sampling_factor_max_ =
            _get_max( horizontal_sampling_factor_, horizontal_sampling_factor_ + component_count_);
  vertical_sampling_factor_max_ =
            _get_max( vertical_sampling_factor_, vertical_sampling_factor_ + component_count_);

  // 一つの MCU は何 x 何ドットか
  MCU_coded_width_ = horizontal_sampling_factor_max_ * 8;
  MCU_coded_height_ = vertical_sampling_factor_max_ * 8;

  // 1ラインの保持に必要なバイト数
  int w = (width_ + MCU_coded_width_ - 1) & ~(MCU_coded_width_ - 1);
  bytes_per_line_ = bytes_per_pixel_ * w;

  // デコードした画像を少しずつためて置く場所
  delete decoded_image_;
  decoded_image_ = new unsigned char[MCU_coded_height_ * bytes_per_line_];

  // デコード時の作業領域
  for (i=0; i<component_count_; ++i) {
    delete coeff_workarea_[i];
    coeff_workarea_[i] =
      new short[64 * horizontal_sampling_factor_[i] * vertical_sampling_factor_[i]];
  }


  return;
}



//-----------------------------------------------------------------------------
// ハフマンテーブルをハフマン木に変換する

static inline
int make_huffman_tree( HuffmanTree *root,
                 int *src_codes, int *src_bits, int table_size )
{
  int code, bits;
  int i, b;
  HuffmanTree *next_tree, *current_tree;

  // 木を生成してゆく
  for (i=0; i<table_size; ++i) {
    code = src_codes[i];
    bits = src_bits[i];
    current_tree = root;

    while (bits--) {

      if (bits == 0) {
        current_tree->value[code & 1] = i;
        break;
      }

      b = (code & (1 << bits)) ? 1 : 0;
      next_tree = current_tree->next_tree[b];

      if (! next_tree) {
        // 新しく枝を生成する
        next_tree = new HuffmanTree();
        current_tree->next_tree[b] = next_tree;
      }

      current_tree = next_tree;
    }
  }

  return 0; // noerror
}

//-----------------------------------------------------------------------------
//   Define Huffman Table

static
void print_bits( unsigned int code, unsigned int bits ) {
  if ( 1 < bits )
    print_bits( code >> 1, bits - 1);
  printf("%c", code & 1 ? '1' : '0');
}

/*static
void print_component( HuffmanTable *table, int comp ) {
  printf(" %3d (zrl %2d, code %2d): ", comp, comp >> 4, comp & 15 );
  print_bits( table[comp].code, table[comp].bits );
  printf("\n");
}*/


void JPEGDecoder::segment_defineHuffmanTable() {

  // 2 バイト .. セグメント長
  int len = (int)stream_.read16() - 2;

  while (len > 0) {
    int i, k, lengths[16], code = 0;
    int codes[256] = {0}, bits[256] = {0};

    // 1 バイト .. テーブルのクラスと識別子
    unsigned int class_and_id = stream_.read8();
      unsigned int table_class = class_and_id >> 4;
      unsigned int table_id    = class_and_id & 15;
    --len;

    // どの木に値を記憶させるかを選択
    HuffmanTree  *tree  = &huffman_trees_[table_class][table_id];

//printf("table for %s-%d\n", table_class ? "ac" : "dc", table_id);

    // 1 バイト x 16 回 .. 後ろに格納されているそれぞれの要素の個数
    for (i=0; i<16; ++i)
      lengths[i] = (int)stream_.read8();
    len -= 16;

    for (i=0; i<16; ++i) {

      // 1 バイト x lengths[i] 回 .. 要素
      for (k=0; k<lengths[i]; ++k) {
        unsigned int value = stream_.read8();
        codes[value] = code++;
        bits[value] = i+1;

//printf("%3d:", value );
//print_bits( codes[value], bits[value] );
//printf("\n");
      }

      code <<= 1;
      len -= lengths[i];
    }

    // テーブルを木に変換
    make_huffman_tree( tree, codes, bits, 256);
  }

  return;
}

//-----------------------------------------------------------------------------
void JPEGDecoder::segment_startOfScan() {

  // 2 バイト .. セグメント長
  int len = (int)stream_.read16() - 2;

  // 1 バイト .. スキャン内の成分数
  int count = (int)stream_.read8();

  for (int i=0; i<count; ++i) {
    // 1 バイト .. スキャン成分セレクタ
    int component_id = (int)stream_.read8();

    // とりあえず 1 バイト
    unsigned int dc_ac_table = stream_.read8();
      // 上位 4 ビット .. DC 成分用ハフマンテーブルセレクタ
      huffman_table_for_dc_[i] = dc_ac_table >> 4;
      // 下位 4 ビット .. AC 成分用ハフマンテーブルセレクタ
      huffman_table_for_ac_[i] = dc_ac_table & 15;
  }

  // 以下 3 バイトはここでは未使用
  unsigned int unused;
  unused = stream_.read8();
  unused = stream_.read8();
  unused = stream_.read8();

  return;
}



//-----------------------------------------------------------------------------
void JPEGDecoder::segment_unknown(){
  int len = (int)stream_.read16() - 2;

  while (len-- > 0) {
    stream_.read8();
  }

  return;
}


//=============================================================================
int JPEGDecoder::getWidth() {
  require_segments();
  return width_;
}

//=============================================================================
int JPEGDecoder::getHeight() {
  require_segments();
  return height_;
}


//=============================================================================
int JPEGDecoder::getDepth() {
  require_segments();
  return depth_;
}

//=============================================================================
int JPEGDecoder::getBPL() {
  require_segments();
  return bytes_per_line_;
}



//-----------------------------------------------------------------------------
/*! バッファをコピーする
 */
static inline
void _copy(void *dst, unsigned int dsize, const void *src ) {

  unsigned int size = dsize;
  unsigned long *ld = (unsigned long *)dst;
  unsigned long *ls = (unsigned long *)src;

  while ( sizeof(unsigned long) <= size ) {
    *ld++ = *ls++;
    size -= sizeof(unsigned long);
  }

  unsigned char *cd = (unsigned char *)ld;
  unsigned char *cs = (unsigned char *)ls;

  while (size--) {
    *cd++ = *cs++;
  }

}


//=============================================================================
/*! 1 ラインを得る
 * @param buf 少なくとも getBPL() メソッドで取得されるバイト数を必ず確保している
 * @return 得られたラインの y 座標
 */
int JPEGDecoder::getLine(void *buf) {
  require_segments();

  int l = getting_line_ & (MCU_coded_height_ - 1);  // getting_line_ % MCU_cocded_height_

  if ( l == 0 )
    decode_MCU_line(); // decoded_buffer_ に MCU_coded_height_ * width_ の範囲をデコード

  _copy(buf, bytes_per_line_, decoded_image_ + l * bytes_per_line_ );

  ++getting_line_;

  return (getting_line_ < height_);
}



//-----------------------------------------------------------------------------
/*! MCU をデコードする
 */
void JPEGDecoder::decode_MCU_line() {
  int cid, vcount[3], hcount[3];
  int height;

  if (getting_line_ & (MCU_coded_height_ - 1))
    throw "assertion failed!";

#define _MIN(_a, _b) (_a) < (_b) ? (_a) : (_b)

  // --- 1 つの MCU を構成する 8x8 ブロックの縦の個数
  // 画像の下部は MCU を構成するブロックが少ないかも
  if (height_ < (getting_line_ + MCU_coded_height_)) {
    int count = (height_ - getting_line_ + 7) / 8;
    for (cid=0; cid<component_count_; ++cid)
      vcount[cid] = _MIN(count, vertical_sampling_factor_[cid]);
    height = count * 8;
  }
  // 通常はこっち
  else {
    for (cid=0; cid<component_count_; ++cid )
      vcount[cid] = vertical_sampling_factor_[cid];
    height = MCU_coded_height_;
  }


  // --- 1 つの MCU を構成する 8x8 ブロックの横の個数
  for (cid=0; cid<component_count_; ++cid) {
    hcount[cid] = horizontal_sampling_factor_[cid];
  }


  for (int x = 0;   x < width_;   x += MCU_coded_width_) {

    // 画像の右端は MCU を構成するブロックが少ないかも
    if (width_ < (x + MCU_coded_width_)) {
      int count = (width_ - x + 7) / 8; // 横方向にあといくつの 8x8 ブロックが存在するか？
      for (cid=0; cid<component_count_; ++cid)
        hcount[cid] = _MIN(count, horizontal_sampling_factor_[cid]);
    }

    for (cid=0; cid < component_count_; ++cid) {
      short *_coeff = coeff_workarea_[cid];

      for (int h=0; h < hcount[cid]; ++h) {
        for (int v=0; v < vcount[cid]; ++v) {

          decode_block(_coeff, cid); // 8x8 ブロックをデコ−ド

          dequantize(_coeff, cid); // 逆量子化

          idct(_coeff); // 逆離散コサイン変換

          _coeff += 64;
        }
      }
    }

    // coeff_workarea_ に展開できた MCU を decoded_image_ に(必要なら拡大して)転送
    block_to_image(x, height, hcount, vcount);
  }
#undef _MIN
}



//-----------------------------------------------------------------------------
/*! ビットを得る
 * スキャン内は 0xff 0x00 の並びを 0xffと扱う
 */
m_ui32 JPEGDecoder::read_bits_scan(unsigned int bits) {
  m_ui32 result = 0;

  while (bits-- > 0) {
    if ( stream_.isAligned() ) {
      if ( skip_next_0x00_ ) {
        skip_next_0x00_ = false;
        if ( stream_.show8() == 0x00 )
          stream_.read8();
      } else if ( stream_.show8() == 0xff )
        skip_next_0x00_ = true;
    }

    result = (result << 1) | stream_.read1();
  }

  return result;
}


//-----------------------------------------------------------------------------
m_ui32 JPEGDecoder::decode_huffman_tree( HuffmanTree *tree ) {
  m_ui32 b;
  HuffmanTree *next;

  while (1) {
    b = read_bits_scan(1);
    next = tree->next_tree[b];
    if (!next) {
      return tree->value[b];
    }
    tree = next;
  }

  return 999; // unreachable
}

//-----------------------------------------------------------------------------
/*! 8x8 ブロックをデコードする
 */
void JPEGDecoder::decode_block(short *coeff, int cid) {
  m_ui32 bits;
  HuffmanTree *tree_dc = &huffman_trees_[0][huffman_table_for_dc_[cid]];
  HuffmanTree *tree_ac = &huffman_trees_[1][huffman_table_for_ac_[cid]];


  // DC 成分をデコード
  bits = decode_huffman_tree( tree_dc );
  if (bits > 16) {
    // エラー
    throw "huffman code exception";
  }

  if (bits) {
    short dc_diff = (short)read_bits_scan(bits);
    short sign = 1 << (bits - 1);
    if ( (dc_diff & sign) == 0 ) dc_diff -= (1 << bits) - 1;

    dc_pred_[cid] += dc_diff;
  }

  coeff[0] = dc_pred_[cid];


  // AC 成分をデコード
  unsigned int p = 1;
  while ( p < 64 ) {
    bits = decode_huffman_tree( tree_ac );
    if (bits >= 999) {
      // エラー
      throw "huffman code exception";
    }

    // end of block .. 残り全部ゼロ
    if (bits == 0) {
      while ( p < 64 )
        coeff[zigzag[p++]] = 0;
      break;
    }

    // 上位ビット(下位4ビットを除いたもの) .. ゼロの個数
    unsigned int zerorun = bits >> 4;
    while (zerorun--) {
      if (63 < p) { throw "out of bounds. zerorun processing..."; }
      coeff[zigzag[p++]] = 0;
    }

    // 下位 4 ビット .. ゼロでない係数のビット長
    bits &= 15;
    if (bits) {
      if (63 < p) { throw "out of bounds. ac reading..."; }
      short ac = (short)read_bits_scan( bits );
      short sign = 1 << (bits - 1);
      if ( (ac & sign) == 0 ) ac -= (1 << bits) - 1;

      coeff[zigzag[p++]] = ac;
    }
  }
}


//-----------------------------------------------------------------------------
//! 逆量子化
void JPEGDecoder::dequantize( short *coeff, int cid ) {
  int i=64;
  short *quant_table = quant_tables_[ (cid != 0) ? 1 : 0];
  while (i--)
    coeff[i] *= quant_table[i];
}


//-----------------------------------------------------------------------------
/*! coeff_workarea_ にデコードされているはずの MCU を decoded_image_ に転送
 * @param x       画像の x 座標
 * @param hcount  MCU を構成する 8x8 ブロックの横個数 ( component_count_ 個の配列 )
 * @param vcount  MCU を構成する 8x8 ブロックの縦個数 ( component_count_ 個の配列 )
 */
void JPEGDecoder::block_to_image( int xpoint, int height, int *hcount, int *vcount ) {

  int cid, ix, iy, cx[3], cy[3]={0}, x_ratio[3], y_ratio[3];

#define V(_n)                             \
  *(                                      \
    (coeff_workarea_[_n]) +               \
    ( 64 * hcount[_n] * (cy[_n] / 8) ) +  \
    ( 64 * (cx[_n] / 8) ) +               \
    (8 * (cy[_n] & 7) + (cx[_n] & 7))     \
  )

#define CLIP(_v) ((unsigned char)(((_v) < 0) ? 0 : (255 < (_v) ? 255 : (_v))))

  for(cid=0; cid<component_count_; ++cid) {
    cy[cid] = 0;
    y_ratio[cid] = 0;
  }

  for (iy=0; iy< height; ++iy) {

    for(cid=0; cid<component_count_; ++cid) {
      cx[cid] = 0;
      x_ratio[cid] = 0;
    }

    for (ix=0; ix< MCU_coded_width_; ++ix) {
      int Y  = (int)V(0) + 127;
      int Cb = (int)V(1);
      int Cr = (int)V(2);

      unsigned char *pix = decoded_image_ +
                         (bytes_per_line_ * iy) +
                         (bytes_per_pixel_ * (ix + xpoint));
      //pix[0] = pix[1] = pix[2] = CLIP(Y);
      pix[0] = CLIP(Y +  17720 * Cb / 10000);
      pix[1] = CLIP(Y - ( 3441 * Cb + 7141 * Cr) / 10000);
      pix[2] = CLIP(Y +  14020 * Cr / 10000 );

      for(cid=0; cid<component_count_; ++cid) {
        x_ratio[cid] += horizontal_sampling_factor_[cid];
        if (horizontal_sampling_factor_max_ <= x_ratio[cid]) {
          x_ratio[cid] -= horizontal_sampling_factor_max_;
          ++cx[cid];
        }
      }

    }

    for(cid=0; cid<component_count_; ++cid) {
      y_ratio[cid] += vertical_sampling_factor_[cid];
      if (vertical_sampling_factor_max_ <= y_ratio[cid]) {
        y_ratio[cid] -= vertical_sampling_factor_max_;
        ++cy[cid];
      }
    }

  }

#undef V
#undef CLIP
}



