--[[
 * rtp_vp8_extractor.lua
 * Wireshark plugin to extract a VP8 stream from RTP packets
 * 
 * Based on rtp_h264_extractor.lua by Volvet Zhang
 * Adapted for VP8 RTP payload format (RFC 7741)
 * Outputs VP8 frames in IVF container format
 *
 * This plugin is free software; you can redistribute it and/or modify it
 * under the terms of the GNU Lesser General Public License.
 *]]

do
    -- Once the reorder (jitter) buffer holds more than this many packets, the
    -- lowest-sequence packet is released for frame reassembly.
    local MAX_JITTER_SIZE = 50
    -- Field extractors read inside the tap's packet callback.
    local vp8_data = Field.new("vp8")
    local rtp_seq = Field.new("rtp.seq")
    local rtp_timestamp = Field.new("rtp.timestamp")
	
    local function extract_vp8_from_rtp()
        --- Build the filter string handed to the tap listener.
        -- The result always restricts to "vp8"; a non-empty caller filter is
        -- AND-ed on, wrapped in parentheses.
        -- @param fd existing display filter; may be nil or ""
        -- @return the combined filter string
        local function dump_filter(fd)
            local base = "vp8"
            if fd == nil or fd == "" then
                return base
            end
            return string.format("%s and (%s)", base, fd)
        end

        -- Tap on "frame" instead of "ip": the "ip" tap only fires for packets
        -- handled by the IPv4 dissector, so VP8 carried over IPv6 was never
        -- seen. The "vp8" filter built by dump_filter still limits the
        -- callback to VP8 packets.
        local vp8_tap = Listener.new("frame", dump_filter(get_filter()))
        local text_window = TextWindow.new("VP8 extractor")
        local filename = ""
        -- Reorder buffer: a list of { key = seq, value = payload, timestamp = ts }.
        local seq_payload_table = { }
        -- 0 while counting packets, 1 while extracting.
        local pass = 0
        -- Payloads seen so far in the extraction pass / total counted in pass 0.
        local packet_count = 0
        local max_packet_count = 0
        -- Frame currently being reassembled (nil until the first frame start).
        local frame_buffer = nil
        -- Sequence number of the previously accepted packet (duplicate filter).
        local pre_seq = 0
        -- Number of frames written to the IVF file so far.
        local frame_count = 0

        -- Error counters: only the first occurrence of each kind is logged,
        -- the rest are tallied and reported in the final summary.
        local error_counts = {
            frame_continuation_no_start = 0,
            frame_seq_gap = 0,
            payload_mismatch = 0,
            incomplete_frame = 0
        }
		
        --- Append one line of text to the extractor's text window.
        -- @param info message to show (a newline is added after it)
        local function log(info)
            text_window:append(info .. "\n")
        end
        
        -- Choose the output path and open the dump file.
        -- get_preference is only available since 3.5.0; without it the file is
        -- written as "dump.ivf" relative to the current working directory.
        if get_preference then
            local fileopen_dir = get_preference("gui.fileopen.dir")
            -- get_preference yields nil for an unknown preference; treat that
            -- like an unset one instead of failing on the concatenation below.
            if fileopen_dir == nil or fileopen_dir == '' then
                log("Wireshark preference 'gui.fileopen.dir' is not set, aborting.")
                -- The listener is already registered; drop it so an aborted
                -- run does not leave a dangling tap behind.
                vp8_tap:remove()
                return
            end
            filename = fileopen_dir  .. "/" .. os.date("video_%Y%m%d-%H%M%S.ivf")
        else
            filename = "dump.ivf"
        end

        log("Dumping VP8 stream to " .. filename)
        local fp = io.open(filename, "wb")
        if fp == nil then
            log("Failed to open dump file '" .. filename .. "'")
            -- Same cleanup as above: nothing will ever use this listener.
            vp8_tap:remove()
            return
        end
        
        --- Write the fixed 32-byte IVF file header to the dump file.
        -- Width, height and frame rate are hard-coded defaults; the frame
        -- count field is written as zero here and patched after extraction.
        local function write_ivf_header()
            local header = table.concat({
                "DKIF",                               -- signature
                string.char(0x00, 0x00),              -- version 0
                string.char(0x20, 0x00),              -- header length 32
                "VP80",                               -- FourCC
                string.char(0x80, 0x02),              -- width 640 (default, adjust if known)
                string.char(0x68, 0x01),              -- height 360 (default, adjust if known)
                string.char(0x1E, 0x00, 0x00, 0x00),  -- frame rate numerator (30)
                string.char(0x01, 0x00, 0x00, 0x00),  -- frame rate denominator (1)
                string.char(0x00, 0x00, 0x00, 0x00),  -- frame count (placeholder, updated later)
                string.char(0x00, 0x00, 0x00, 0x00),  -- unused
            })
            fp:write(header)
        end

        write_ivf_header()
        
        --- Ordering function for jitter-buffer entries (used with table.sort).
        -- Entries whose sequence numbers are less than 1000 apart compare
        -- numerically; a larger distance is taken to mean the 16-bit counter
        -- wrapped, so the numerically larger key sorts first.
        -- @param left  entry with a numeric `key`
        -- @param right entry with a numeric `key`
        -- @return true when `left` must come before `right`
        local function seq_compare(left, right)
            local a, b = left.key, right.key
            local close_together = math.abs(b - a) < 1000
            if close_together then
                return a < b
            end
            return a > b
        end
        
        --- Append one frame (12-byte IVF frame header plus data) to the dump file.
        -- Does nothing once the file has been closed.
        -- @param frame_data object providing :len() and :tvb() for the frame bytes
        -- @param timestamp  accepted but not used; the PTS written is the
        --                   running frame_count
        local function write_ivf_frame(frame_data, timestamp)
            if fp == nil then
                return
            end

            -- Encode the low 32 bits of a number as little-endian bytes.
            local function le32(value)
                return string.char(
                    bit.band(value, 0xFF),
                    bit.band(bit.rshift(value, 8), 0xFF),
                    bit.band(bit.rshift(value, 16), 0xFF),
                    bit.band(bit.rshift(value, 24), 0xFF)
                )
            end

            -- IVF frame header: 4 bytes size + 8 bytes PTS (upper 4 bytes zero)
            local size_field = le32(frame_data:len())
            fp:write(size_field)
            local pts_field = le32(frame_count) .. string.char(0x00, 0x00, 0x00, 0x00)
            fp:write(pts_field)

            -- Frame payload
            local raw_bytes = frame_data:tvb()():raw()
            fp:write(raw_bytes)
            fp:flush()

            frame_count = frame_count + 1
        end
        
        --- Write a reassembled frame to the IVF file, or count it as dropped.
        -- A frame is written only when it is marked complete and holds at
        -- least one payload; anything else increments the incomplete-frame
        -- counter (logged on first occurrence only).
        -- @param frame_buffer table with `complete`, `payloads` and `timestamp`
        local function dump_complete_frame(frame_buffer)
            if frame_buffer.complete == true and #frame_buffer.payloads > 0 then
                log("Dumping complete VP8 frame with "..tostring(#frame_buffer.payloads).." packets")
                -- Collect the raw bytes of every payload and join them once;
                -- repeated `..` in the loop re-copied the growing string for
                -- each packet (quadratic in frame size).
                local chunks = {}
                for i, payload in ipairs(frame_buffer.payloads) do
                    chunks[i] = payload:tvb()():raw()
                end
                -- Wrap the bytes in a ByteArray-backed Tvb, the form
                -- write_ivf_frame expects
                local ba = ByteArray.new(table.concat(chunks), true)
                local frame_tvb = ba:tvb("VP8 Frame")
                write_ivf_frame(frame_tvb:range(), frame_buffer.timestamp)
            else
                error_counts.incomplete_frame = error_counts.incomplete_frame + 1
                if error_counts.incomplete_frame == 1 then
                    log("Incomplete VP8 frame, dropped (first occurrence, further similar errors will be counted)")
                end
            end
        end
        
        --- Parse one in-order VP8 RTP payload (RFC 7741) and add it to the
        -- frame being reassembled.
        -- Strips the VP8 payload descriptor, then either opens a new frame
        -- (first packet of a frame) or appends to the current one. Packets
        -- too short to hold their descriptor plus data are ignored.
        -- @param seq         RTP sequence number (number)
        -- @param vp8_payload payload bytes, descriptor included
        -- @param timestamp   RTP timestamp stored with the frame
        local function handle_vp8_payload(seq, vp8_payload, timestamp)
            if vp8_payload:len() < 1 then
                return
            end

            -- First descriptor octet: |X|R|N|S|R| PID |
            local payload_descriptor = vp8_payload:get_index(0)
            local X_bit = bit.band(payload_descriptor, 0x80) -- Extended control bits present
            local S_bit = bit.band(payload_descriptor, 0x10) -- Start of VP8 partition
            local PID = bit.band(payload_descriptor, 0x07)   -- Partition index

            local offset = 1

            -- Parse extended control bits if present
            if X_bit ~= 0 then
                if vp8_payload:len() < offset + 1 then return end
                local ext_byte = vp8_payload:get_index(offset)
                offset = offset + 1

                local I_bit = bit.band(ext_byte, 0x80)  -- PictureID present
                local L_bit = bit.band(ext_byte, 0x40)  -- TL0PICIDX present
                local T_bit = bit.band(ext_byte, 0x20)  -- TID present
                local K_bit = bit.band(ext_byte, 0x10)  -- KEYIDX present

                if I_bit ~= 0 then
                    if vp8_payload:len() < offset + 1 then return end
                    local pic_id = vp8_payload:get_index(offset)
                    if bit.band(pic_id, 0x80) ~= 0 then
                        offset = offset + 2  -- Extended PictureID (15 bits)
                    else
                        offset = offset + 1  -- Short PictureID (7 bits)
                    end
                end

                if L_bit ~= 0 then
                    offset = offset + 1
                end

                -- TID and KEYIDX share a single octet
                if T_bit ~= 0 or K_bit ~= 0 then
                    offset = offset + 1
                end
            end

            -- Also catches descriptors whose optional octets run past the end
            if offset >= vp8_payload:len() then
                return
            end

            -- Extract VP8 payload data (skip RTP payload descriptor)
            local vp8_frame_data = vp8_payload:tvb()(offset)

            -- S only marks the start of a *partition*; the first packet of a
            -- frame is the one with S = 1 and PID = 0. Testing S alone split a
            -- frame whenever a later partition began on a packet boundary.
            local is_frame_start = (S_bit ~= 0) and (PID == 0)

            if is_frame_start then
                -- Hand over the previous frame even when it is incomplete so
                -- that dump_complete_frame can count the drop; previously
                -- incomplete frames were discarded here without being counted.
                if frame_buffer ~= nil then
                    dump_complete_frame(frame_buffer)
                end

                frame_buffer = {
                    payloads = {},
                    seq = seq,
                    complete = true,
                    timestamp = timestamp
                }
                table.insert(frame_buffer.payloads, vp8_frame_data)
                log("VP8 frame start: seq = "..tostring(seq))
            else
                -- Continuation of current frame
                if frame_buffer == nil then
                    error_counts.frame_continuation_no_start = error_counts.frame_continuation_no_start + 1
                    if error_counts.frame_continuation_no_start == 1 then
                        log("VP8 frame continuation without start, dropped (first occurrence)")
                    end
                    return
                end

                -- The frame is already broken by an earlier gap: drop the rest
                -- of it quietly so one lost packet is counted as one gap
                -- rather than once per remaining packet.
                if not frame_buffer.complete then
                    return
                end

                local expected_seq = (frame_buffer.seq + 1) % 65536
                if seq ~= expected_seq then
                    error_counts.frame_seq_gap = error_counts.frame_seq_gap + 1
                    if error_counts.frame_seq_gap == 1 then
                        log("VP8 frame: sequence gap detected (first: expected "..tostring(expected_seq)..", got "..tostring(seq)..")")
                    end
                    frame_buffer.complete = false
                    return
                end

                frame_buffer.seq = seq
                table.insert(frame_buffer.payloads, vp8_frame_data)
            end
        end
		
        --- Receive one payload released from the jitter buffer and forward it
        -- unchanged to handle_vp8_payload.
        -- @param seq         RTP sequence number
        -- @param vp8_payload payload bytes, descriptor included
        -- @param timestamp   RTP timestamp
        local function on_ordered_vp8_payload(seq, vp8_payload, timestamp)
            handle_vp8_payload(seq, vp8_payload, timestamp)
        end
        
        --- Release the lowest-ordered packet from the jitter buffer.
        -- Sorts the buffer with seq_compare, passes the first entry on for
        -- frame reassembly and then removes it. An empty buffer is a no-op.
        local function on_jitter_buffer_output()
            table.sort(seq_payload_table, seq_compare)

            local oldest = seq_payload_table[1]
            if oldest == nil then
                return
            end

            on_ordered_vp8_payload(oldest.key, oldest.value, oldest.timestamp)
            table.remove(seq_payload_table, 1)
        end
        
        --- Drain whatever is left in the jitter buffer and flush the last frame.
        -- Called once the extraction pass has seen every counted payload.
        local function jitter_buffer_finalize()
            -- The buffer is only sorted when it overflows MAX_JITTER_SIZE, so
            -- a capture with fewer packets than that was never sorted at all
            -- and would be drained in capture order. Sort before draining.
            table.sort(seq_payload_table, seq_compare)
            for _, obj in ipairs(seq_payload_table) do
                on_ordered_vp8_payload(obj.key, obj.value, obj.timestamp)
            end
            -- Everything has been handed over; do not keep stale entries.
            seq_payload_table = { }

            -- Flush the last frame. It is passed on even when incomplete so
            -- that dump_complete_frame counts the drop instead of it being
            -- silently discarded.
            if frame_buffer ~= nil then
                dump_complete_frame(frame_buffer)
                frame_buffer = nil
            end
        end
        
        --- Accept one VP8 payload during the extraction pass.
        -- Every call is counted. A payload whose sequence number equals that
        -- of the previously accepted one is treated as a duplicate and
        -- skipped; otherwise it is queued in the jitter buffer, which releases
        -- one packet whenever it holds more than MAX_JITTER_SIZE entries.
        -- @param seq       field info for rtp.seq
        -- @param payload   field info for the vp8 payload
        -- @param timestamp field info for rtp.timestamp
        local function on_vp8_rtp_payload(seq, payload, timestamp)
            local seq_value = seq.value
            local is_first_packet = (packet_count == 0)

            packet_count = packet_count + 1

            -- Back-to-back repeat of the same sequence number: already queued
            if not is_first_packet and seq_value == pre_seq then
                return
            end
            pre_seq = seq_value

            seq_payload_table[#seq_payload_table + 1] = {
                key = tonumber(seq_value),
                value = payload.value,
                timestamp = tonumber(timestamp.value)
            }

            if #seq_payload_table > MAX_JITTER_SIZE then
                on_jitter_buffer_output()
            end
        end
        
        --- Tap callback, invoked for every packet matching the listener filter.
        -- Pass 0 only counts VP8 payloads; pass 1 feeds them to the jitter
        -- buffer and finalizes once the counted total has been reached.
        -- @param pinfo packet info (unused)
        -- @param tvb   packet buffer (unused)
        function vp8_tap.packet(pinfo, tvb)
            local payloadTable = { vp8_data() }
            local seqTable = { rtp_seq() }
            local timestampTable = { rtp_timestamp() }

            -- A payload without an RTP sequence number or timestamp cannot be
            -- ordered; passing it on would index a nil field in
            -- on_vp8_rtp_payload.
            local missing_rtp_fields = (#payloadTable > 0)
                and (#seqTable == 0 or #timestampTable == 0)

            if (#payloadTable) < (#seqTable) or missing_rtp_fields then
                -- The same packet is rejected in both passes; record it in
                -- the counting pass only so the summary is not doubled.
                if pass == 0 then
                    error_counts.payload_mismatch = error_counts.payload_mismatch + 1
                    if error_counts.payload_mismatch == 1 then
                        log("ERROR: payloadTable size is "..tostring(#payloadTable)..", seqTable size is "..tostring(#seqTable).." (first occurrence)")
                    end
                end
                return
            end

            if pass == 0 then
                max_packet_count = max_packet_count + #payloadTable
            else
                -- Every payload of the packet is tagged with the packet's
                -- first sequence number and timestamp.
                for _, payload in ipairs(payloadTable) do
                    on_vp8_rtp_payload(seqTable[1], payload, timestampTable[1])
                end

                -- Last counted payload seen: flush while still inside the
                -- tap callback.
                if packet_count == max_packet_count then
                    jitter_buffer_finalize()
                end
            end
        end
		
        --- Tap reset callback; intentionally empty, no state is cleared here.
        -- NOTE(review): presumably Wireshark invokes this around each retap;
        -- the counters from the counting pass must survive into the
        -- extraction pass, so confirm before adding resets.
        function vp8_tap.reset()
        end
		
        --- Tap draw callback; intentionally empty, all output goes through log().
        function vp8_tap.draw() 
        end
		
        --- Clean-up handler run when the text window is closed.
        -- Closes the dump file if it is still open, then unregisters the tap.
        local function remove()
            if fp ~= nil then
                fp:close()
                fp = nil
            end
            vp8_tap:remove()
        end
		
        -- Run the extraction: count the packets, extract them, patch the
        -- header, then report what went wrong along the way.
        log("Start VP8 extraction")
        text_window:set_atclose(remove)

        -- First pass: only count the VP8 payloads
        log("Phase 1: Counting packets")
        pass = 0
        retap_packets()

        -- Second pass: reorder, reassemble and write the frames
        log("Phase 2: Extracting stream (max_packet_count = "..tostring(max_packet_count)..")")
        pass = 1
        retap_packets()

        if fp ~= nil then
            -- Patch the real frame count into the IVF header (little-endian,
            -- byte offset 24), then close the file
            local frame_count_le = string.char(
                bit.band(frame_count, 0xFF),
                bit.band(bit.rshift(frame_count, 8), 0xFF),
                bit.band(bit.rshift(frame_count, 16), 0xFF),
                bit.band(bit.rshift(frame_count, 24), 0xFF)
            )
            fp:seek("set", 24)
            fp:write(frame_count_le)
            fp:close()
            fp = nil
            log("VP8 video stream written to " .. filename .. " ("..tostring(frame_count).." frames)")
        end

        -- Error summary: one line per error kind that occurred, then the total
        log("\n=== Error Summary ===")
        local summary_rows = {
            { "frame_continuation_no_start", "Frame continuations without start: " },
            { "frame_seq_gap", "Frame sequence gaps: " },
            { "incomplete_frame", "Incomplete frames dropped: " },
            { "payload_mismatch", "Payload/Seq table mismatches: " },
        }
        local total_errors = 0
        for _, row in ipairs(summary_rows) do
            local occurrences = error_counts[row[1]]
            if occurrences > 0 then
                log(row[2] .. tostring(occurrences) .. " occurrences")
            end
            total_errors = total_errors + occurrences
        end

        if total_errors == 0 then
            log("No errors encountered during extraction")
        else
            log("Total errors: " .. tostring(total_errors))
        end

        log("Extraction complete")
	end

    -- Register the extractor as a menu entry (MENU_TOOLS_UNSORTED group);
    -- selecting it runs extract_vp8_from_rtp.
    register_menu("Extract VP8 stream from RTP", extract_vp8_from_rtp, MENU_TOOLS_UNSORTED)
end
