Skip to main content

Caption Extraction Script

Complete Python script for extracting the filename and caption columns from Visual Layer’s internal parquet files and stripping a path prefix (default: /hostfs) from each filename.

Back to Usage Guide

See the main guide for detailed usage instructions, workflow, and examples.

Installation

pip install pandas pyarrow

Quick Usage

# Basic usage
python3 process_annotations.py /.vl/tmp/[dataset-id]/input/metadata/image_annotations.parquet

# Specify output location
python3 process_annotations.py input.parquet -o /path/to/output.parquet

# Custom prefix removal
python3 process_annotations.py input.parquet --prefix /custom/prefix

Script Code

#!/usr/bin/env python3
"""
Process parquet annotation files to extract filename and caption columns.
Removes path prefixes from filenames.
"""

import argparse
import sys
from pathlib import Path
import pandas as pd


def process_parquet(input_path, output_path=None, prefix_to_remove='/hostfs'):
    """
    Process a parquet file to extract filename and caption columns.

    Reads ``input_path`` and keeps only the ``filename`` and ``caption``
    columns. It then removes ``prefix_to_remove`` from the start of every
    string filename that begins with it, and writes the result to a new
    parquet file.

    Args:
        input_path: Path to input parquet file (str or Path). The ``.parquet``
            suffix check is case-insensitive.
        output_path: Path to output parquet file (optional). Defaults to
            ``<input stem>_processed.parquet`` in the input's directory.
            Missing parent directories are created.
        prefix_to_remove: Prefix to remove from filenames (default: '/hostfs').
            Only a leading occurrence is removed. Non-string values (e.g.
            missing filenames) are left unchanged.

    Returns:
        Path to output file

    Raises:
        FileNotFoundError: If ``input_path`` does not exist.
        ValueError: If the input does not have a ``.parquet`` suffix or is
            missing the ``filename``/``caption`` columns.
        RuntimeError: If the input cannot be read as parquet. The original
            exception is chained as ``__cause__``.
    """
    # Validate input file
    input_file = Path(input_path)
    if not input_file.exists():
        raise FileNotFoundError(f"Input file not found: {input_path}")

    # Compare case-insensitively so e.g. "data.PARQUET" is accepted as well.
    if input_file.suffix.lower() != '.parquet':
        raise ValueError(f"Input file must be a parquet file, got: {input_file.suffix}")

    # Determine output path
    if output_path is None:
        output_file = input_file.parent / f"{input_file.stem}_processed.parquet"
    else:
        output_file = Path(output_path)

    print(f"Reading parquet file: {input_file}")

    # Read parquet file
    try:
        df = pd.read_parquet(input_file)
    except Exception as e:
        # Chain the underlying error so its type and traceback are preserved.
        raise RuntimeError(f"Failed to read parquet file: {e}") from e

    print(f"Original shape: {df.shape}")
    print(f"Columns: {df.columns.tolist()}")

    # Validate required columns exist
    required_columns = ['filename', 'caption']
    missing_columns = [col for col in required_columns if col not in df.columns]
    if missing_columns:
        raise ValueError(f"Missing required columns: {missing_columns}")

    # Select only filename and caption columns. The explicit .copy() yields an
    # independent frame, so assigning a column below does not trigger pandas'
    # SettingWithCopyWarning.
    df = df[required_columns].copy()

    # Remove prefix from filename. Slicing off len(prefix) characters after a
    # startswith() check removes exactly the leading occurrence.
    print(f"Removing prefix '{prefix_to_remove}' from filenames...")
    prefix_len = len(prefix_to_remove)
    df['filename'] = df['filename'].apply(
        lambda x: x[prefix_len:]
        if isinstance(x, str) and x.startswith(prefix_to_remove)
        else x
    )

    # Show sample of processed data
    print(f"\nProcessed shape: {df.shape}")
    print("\nSample filenames after processing:")
    print(df['filename'].head(3).tolist())

    # Save to output file, creating its directory first if necessary.
    print(f"\nSaving to: {output_file}")
    output_file.parent.mkdir(parents=True, exist_ok=True)
    df.to_parquet(output_file, index=False)

    print(f"✓ Successfully processed {len(df)} rows")
    print(f"✓ Output saved to: {output_file}")

    return output_file


def main(argv=None):
    """
    Command-line entry point for the caption extraction script.

    Args:
        argv: Argument list to parse, excluding the program name. ``None``
            (the default) parses ``sys.argv[1:]``, which matches the previous
            zero-argument behaviour. Passing a list lets the CLI be driven
            from tests or other scripts.

    Returns:
        Exit status: 0 on success, or 1 if processing raised an error.
        argparse exits on its own for ``--help`` (status 0) and for invalid
        arguments (status 2).
    """
    parser = argparse.ArgumentParser(
        description='Process parquet annotation files to extract filename and caption columns.',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  %(prog)s input.parquet
  %(prog)s input.parquet -o output.parquet
  %(prog)s input.parquet --prefix /hostfs/mnt
        """
    )

    parser.add_argument(
        'input',
        help='Path to input parquet file'
    )

    parser.add_argument(
        '-o', '--output',
        help='Path to output parquet file (default: <input>_processed.parquet)'
    )

    parser.add_argument(
        '--prefix',
        default='/hostfs',
        help='Prefix to remove from filenames (default: /hostfs)'
    )

    args = parser.parse_args(argv)

    try:
        process_parquet(args.input, args.output, args.prefix)
    except Exception as e:
        # Top-level CLI boundary: report the error concisely and signal
        # failure through the exit status instead of printing a traceback.
        print(f"Error: {e}", file=sys.stderr)
        return 1
    return 0


if __name__ == '__main__':
    # Run the CLI and hand its return code to the interpreter as the exit status.
    raise SystemExit(main())