/*------------------------------------------------------------------------
 *
 * Copyright (c) 1997-1998 by Cornell University.
 * 
 * See the file "license.txt" for information on usage and redistribution
 * of this file, and for a DISCLAIMER OF ALL WARRANTIES.
 *
 * mpgselect.c
 *
 * Dan Rabinovitz dsr@cs.cornell.edu
 *
 * Usage : mpgselect inputMPG startFrame endFrame
 *
 * Makes two passes over the MPEG file: the first pass creates an index,
 * and the second pass uses the index to decode grayscale versions of the
 * selected frames (startFrame up to, but not including, endFrame).
 *
 *------------------------------------------------------------------------
 */
#include <errno.h>
#include <limits.h>
#include <stdio.h>
#include <stdlib.h>

#include "dvmbasic.h"
#include "dvmpnm.h"
#include "dvmmpeg.h"
#include "dvmcolor.h"

/*
 * Exchange the integers that x and y point to.
 */
void
SwapInt (int *x, int *y)
{
    const int saved = *y;

    *y = *x;
    *x = saved;
}

/*
 * Exchange the two ByteImage pointers stored at x and y.  Only the
 * pointers move; the images themselves are not copied.
 */
void
SwapByteImage (ByteImage **x, ByteImage **y)
{
    ByteImage *const saved = *y;

    *y = *x;
    *x = saved;
}

/*
 * Scan the MPEG video stream read by bp and build an index entry for
 * every picture.
 *
 * Each picture is keyed by frame number = (number of pictures seen
 * before the current GOP) + (the picture's temporal reference).  For
 * each picture the entry records:
 *   - the parser offset of its picture header;
 *   - its picture type;
 *   - the value returned by MpegAnyHdrFind after the header, i.e. the
 *     distance to the next start code;
 *   - the distance (in frames) to its past reference (P and B frames);
 *   - the distance to its future reference (B frames).
 *
 * Scanning stops at a sequence end code or at any start code not handled
 * below.  bp is rewound to offset 0 before returning.  The caller owns
 * the returned index (free with MpegVideoIndexFree).
 */
MpegVideoIndex *
MakeMpegVideoIndex (bp)
    BitParser *bp;
{
    int size = 100;
    int firstPicInGop = 0;
    int picCount = 0;
    int done = 0;
    int past = 0;	/* frame number of the older I/P reference */
    int future = 0;	/* frame number of the newest I/P reference */
    int code, offset, tempref, len, frameNum;
    char type;
    MpegVideoIndex *index = MpegVideoIndexNew (size);
    MpegPicHdr *ph = MpegPicHdrNew ();

    while (!done) {
	MpegAnyHdrFind (bp);
	code = MpegGetCurrStartCode (bp);
	switch (code) {
	case SEQ_START_CODE:
	    /* The sequence header is not needed for the index; skip it. */
	    MpegSeqHdrSkip (bp);
	    break;
	case GOP_START_CODE:
	    /* Temporal references restart at each GOP, so remember how
	     * many pictures preceded it. */
	    firstPicInGop = picCount;
	    MpegGopHdrSkip (bp);
	    break;
	case PIC_START_CODE:
	    offset = BitParserTell (bp);
	    MpegPicHdrParse (bp, ph);
	    type = MpegPicHdrGetType (ph);
	    tempref = MpegPicHdrGetTemporalRef (ph);
	    len = MpegAnyHdrFind (bp);
	    frameNum = tempref + firstPicInGop;
	    /*
	     * Slots 0..size-1 are valid (assumes MpegVideoIndexNew/Resize
	     * allocate exactly `size` entries -- confirm), so grow whenever
	     * frameNum reaches size, doubling until it fits.
	     */
	    if (frameNum >= size) {
		while (frameNum >= size) {
		    size <<= 1;
		}
		MpegVideoIndexResize (index, size);
	    }
	    if (type == I_FRAME || type == P_FRAME) {
		/* A new reference frame: the previous newest becomes past. */
		past = future;
		future = frameNum;
		MpegVideoIndexTableAdd (index, frameNum, offset, type, len,
			(type == P_FRAME) ? past - frameNum : 0, 0);
	    } else {
		/* B frame: refers to the two most recent reference frames. */
		MpegVideoIndexTableAdd (index, frameNum, offset, type, len,
			past - frameNum, future - frameNum);
	    }
	    picCount++;
	    break;
	case SEQ_END_CODE:
	    done = 1;
	    break;
	default:
	    done = 1;
	}
    }

    MpegPicHdrFree (ph);
    BitParserSeek (bp, 0);
    return index;
}

/*
 * Write the grayscale image y to the file called name in PGM format.
 *
 * The bitstream bs (wrapped by bp) must already contain an encoded PNM
 * header, with bp positioned just after it.  The pixels of y are
 * encoded after the header, and bs is written to the file from offset
 * 0 (presumably header plus pixels -- confirm BitStreamFileWrite
 * semantics).  bp is then moved back to the end of the header, so the
 * same header and bitstream can be reused for the next image.
 *
 * Exits the program if the file cannot be opened or closed.
 */
void WritePGM (y, bs, bp, name) 
    ByteImage *y;
    BitStream *bs;
    BitParser *bp;
    char *name;
{
    FILE *chan;
    int curr;

    chan = fopen(name, "wb");
    if (chan == NULL) {
	fprintf(stderr, "unable to open %s for writing.\n", name);
	exit(1);
    }
    curr = BitParserTell(bp);
    PgmEncode(y, bp);
    BitStreamFileWrite(bs, chan, 0);
    BitParserSeek(bp, curr);
    /* fclose flushes buffered data; a failure means the file is incomplete. */
    if (fclose(chan) != 0) {
	fprintf(stderr, "error while writing %s.\n", name);
	exit(1);
    }
}


/*
 * Parse one frame-number command-line argument.  Accepts only a complete
 * decimal integer in [0, INT_MAX].  On success stores the value in *out
 * and returns 1; otherwise returns 0 and leaves *out unchanged.
 */
static int
ParseFrameNumber(const char *text, int *out)
{
    char *endp;
    long value;

    errno = 0;
    value = strtol(text, &endp, 10);
    if (endp == text || *endp != '\0' || errno == ERANGE
	    || value < 0 || value > INT_MAX) {
	return 0;
    }
    *out = (int) value;
    return 1;
}

/*
 * mpgselect inputMPG startFrame endFrame
 *
 * Decode frames startFrame .. endFrame-1 of an MPEG video stream to
 * grayscale, and write each one to a PGM file named NNN<t>.pgm.  NNN is
 * the frame number and <t> is i, p or b for the picture type.
 */
int main(int argc, char *argv[])
{
    BitParser *bp;
    BitStream *bs;
    BitParser *outbp;
    BitStream *outbs;
    MpegSeqHdr *sh;
    MpegPicHdr *fh;
    ByteImage *y;
    ByteImage *prevy, *futurey;
    ScImage *scy, *scu, *scv;
    VectorImage *fwdmv, *bwdmv;
    PnmHdr *pnmhdr;
    int start, end, offset, currFrame, futureFrame, nextAnchor;
    int i, w, h, seqw, seqh, type, size;
    char outname[100];
    MpegVideoIndex *index, *tempIndex;

    /*
     * Check arguments, open file, and initialize BitStream.  argv[1..3]
     * are all used, so argc must be at least 4.
     */
    if (argc < 4) {
	fprintf(stderr, "usage : %s input startFrame endFrame\n", argv[0]);
	exit(1);
    }
    if (!ParseFrameNumber(argv[2], &start)
	    || !ParseFrameNumber(argv[3], &end)) {
	fprintf(stderr, "startFrame and endFrame must be non-negative integers\n");
	exit(1);
    }
    bs = BitStreamMmapReadNew(argv[1]);
    if (bs == NULL) {
	fprintf(stderr, "unable to open %s for reading.\n", argv[1]);
	exit(1);
    }
    bp = BitParserNew();
    BitParserWrap(bp, bs);

    /* First pass: index every picture (also rewinds bp to offset 0). */
    index = MakeMpegVideoIndex(bp);

    /*
     * Allocate a new sequence header, skip any leading garbage in the
     * input file, and read in the sequence header.
     */
    sh = MpegSeqHdrNew();
    MpegSeqHdrFind(bp);
    MpegSeqHdrParse(bp, sh);

    /* Round the frame size up to a whole number of 16x16 macroblocks. */
    seqw = MpegSeqHdrGetWidth(sh);
    seqh = MpegSeqHdrGetHeight(sh);
    w = (seqw + 15) / 16 * 16;
    h = (seqh + 15) / 16 * 16;

    /*
     * Luminance-only working images:
     *   - y is the frame being decoded;
     *   - prevy is the past reference frame;
     *   - futurey is the future reference frame.
     * scy, scu and scv receive the DCT-coded picture from the bitstream.
     * Only scy is converted to pixels.  fwdmv and bwdmv are the forward
     * and backward motion vectors.
     */
    y       = ByteNew (w, h);
    prevy   = ByteNew (w, h);
    futurey = ByteNew (w, h);
    fwdmv   = VectorNew (w/16, h/16);
    bwdmv   = VectorNew (w/16, h/16);
    scy     = ScNew (w/8, h/8);
    scu     = ScNew (w/16, h/16);
    scv     = ScNew (w/16, h/16);

    /*
     * Encode the PGM header once into the output bitstream; every frame
     * has the same header and WritePGM appends pixels after it.
     * NOTE(review): the header advertises seqw x seqh, while the images
     * are w x h (rounded up to 16).  Confirm that PgmEncode emits only
     * seqw x seqh pixels when these differ.
     */
    pnmhdr = PnmHdrNew();
    PnmHdrSetType(pnmhdr, PGM_BIN);
    PnmHdrSetWidth(pnmhdr, seqw);
    PnmHdrSetHeight(pnmhdr, seqh);
    PnmHdrSetMaxVal(pnmhdr, 255);
    outbs = BitStreamNew(seqw*seqh + 20);
    outbp = BitParserNew();
    BitParserWrap(outbp, outbs);
    PnmHdrEncode(pnmhdr, outbp);
    PnmHdrFree(pnmhdr);

    /*
     * Decode the reference frames that frame `start` depends on.
     * MpegVideoIndexFindRefs copies their index entries into tempIndex,
     * and they are decoded from the last entry to the first.  After each
     * picture, futurey holds the newest decoded frame and prevy the one
     * before it.  The final swap leaves the newest one in prevy.
     */
    fh = MpegPicHdrNew();
    size = MpegVideoIndexNumRefs(index, start);
    tempIndex = MpegVideoIndexNew(size);
    MpegVideoIndexFindRefs(index, tempIndex, start);
    for (i = size - 1; i >= 0; i--) {
	offset = MpegVideoIndexGetOffset(tempIndex, i);
	BitParserSeek(bp, offset);
	MpegPicHdrParse(bp, fh);
	type = MpegPicHdrGetType(fh);
	SwapByteImage(&futurey, &prevy);
	if (type == I_FRAME) {
	    MpegPicIParse(bp, sh, fh, scy, scu, scv);
	    ScIToByte(scy, y);
	} else if (type == P_FRAME) {
	    MpegPicPParse(bp, sh, fh, scy, scu, scv, fwdmv);
	    ScPToY(scy, fwdmv, prevy, y);
	}
	SwapByteImage(&y, &futurey);
    }
    SwapByteImage(&prevy, &futurey);
    MpegVideoIndexFree(tempIndex);

    /*
     * Second pass: decode and write frames start .. end-1 in display
     * order.  futureFrame is the frame number currently decoded in
     * futurey for the B frames that precede it, or -1 if none.
     * NOTE(review): end is not checked against the number of indexed
     * frames; confirm MpegVideoIndexGetOffset's behavior out of range.
     */
    futureFrame = -1;
    for (currFrame = start; currFrame < end; currFrame++) {
	offset = MpegVideoIndexGetOffset(index, currFrame);
	BitParserSeek(bp, offset);
	MpegPicHdrParse(bp, fh);
	type = MpegPicHdrGetType(fh);
	if (type == I_FRAME || type == P_FRAME) {
	    snprintf(outname, sizeof outname, "%03d%c.pgm", currFrame,
		    (type == I_FRAME) ? 'i' : 'p');
	    if (currFrame == futureFrame) {
		/*
		 * Already decoded into futurey while handling the preceding
		 * B frames (y holds the last B frame), so write futurey.
		 * It then becomes the past reference.
		 */
		WritePGM(futurey, outbs, outbp, outname);
		SwapByteImage(&futurey, &prevy);
		futureFrame = -1;
	    } else {
		if (type == I_FRAME) {
		    MpegPicIParse(bp, sh, fh, scy, scu, scv);
		    ScIToByte(scy, y);
		} else {
		    MpegPicPParse(bp, sh, fh, scy, scu, scv, fwdmv);
		    ScPToY(scy, fwdmv, prevy, y);
		}
		WritePGM(y, outbs, outbp, outname);
		SwapByteImage(&y, &prevy);
	    }
	} else {
	    /* B frame (any other picture type is treated the same way). */
	    nextAnchor = currFrame + MpegVideoIndexGetNext(index, currFrame);
	    if (nextAnchor != futureFrame) {
		/*
		 * Decode the future reference frame into futurey.  This is
		 * done once per run of B frames that share it.
		 */
		futureFrame = nextAnchor;
		offset = MpegVideoIndexGetOffset(index, futureFrame);
		BitParserSeek(bp, offset);
		MpegPicHdrParse(bp, fh);
		type = MpegPicHdrGetType(fh);
		if (type == I_FRAME) {
		    MpegPicIParse(bp, sh, fh, scy, scu, scv);
		    ScIToByte(scy, futurey);
		} else if (type == P_FRAME) {
		    MpegPicPParse(bp, sh, fh, scy, scu, scv, fwdmv);
		    ScPToY(scy, fwdmv, prevy, futurey);
		}
		/* Return to the B picture's header. */
		offset = MpegVideoIndexGetOffset(index, currFrame);
		BitParserSeek(bp, offset);
		MpegPicHdrParse(bp, fh);
	    }
	    MpegPicBParse(bp, sh, fh, scy, scu, scv, fwdmv, bwdmv);
	    ScBToY(scy, fwdmv, bwdmv, prevy, futurey, y);
	    snprintf(outname, sizeof outname, "%03db.pgm", currFrame);
	    WritePGM(y, outbs, outbp, outname);
	}
    }

    /*
     * Release everything allocated above.
     */
    MpegPicHdrFree(fh);
    MpegSeqHdrFree(sh);
    BitParserFree(bp);
    BitStreamMmapReadFree(bs);
    BitParserFree(outbp);
    BitStreamFree(outbs);
    ByteFree(y);
    ByteFree(prevy);
    ByteFree(futurey);
    ScFree(scy);
    ScFree(scu);
    ScFree(scv);
    VectorFree(fwdmv);
    VectorFree(bwdmv);
    MpegVideoIndexFree(index);

    return 0;
}