|
| 1 | +#!/usr/bin/env python |
| 2 | +""" |
| 3 | +ArXiv category code to user-friendly name converter. |
| 4 | +Called by arxiv_fetch.py to convert category codes to readable names. |
| 5 | +""" |
| 6 | +import csv |
| 7 | +import os |
| 8 | +import yaml |
| 9 | + |
| 10 | +def load_category_mapping(data_dir): |
| 11 | + """Load category code to label mapping from YAML file.""" |
| 12 | + mapping_file = os.path.join(data_dir, "arxiv_category_map.yaml") |
| 13 | + |
| 14 | + if not os.path.exists(mapping_file): |
| 15 | + return {} |
| 16 | + |
| 17 | + try: |
| 18 | + with open(mapping_file, 'r') as f: |
| 19 | + return yaml.safe_load(f) or {} |
| 20 | + except Exception: |
| 21 | + return {} |
| 22 | + |
| 23 | +def convert_categories_to_friendly_names(input_file, output_file, data_dir): |
| 24 | + """ |
| 25 | + Convert category codes in CSV to user-friendly names. |
| 26 | + |
| 27 | + Args: |
| 28 | + input_file: Path to input CSV with category codes |
| 29 | + output_file: Path to output CSV with friendly names |
| 30 | + data_dir: Directory containing arxiv_category_map.yaml |
| 31 | + """ |
| 32 | + if not os.path.exists(input_file): |
| 33 | + return |
| 34 | + |
| 35 | + # Load category mapping |
| 36 | + category_mapping = load_category_mapping(data_dir) |
| 37 | + |
| 38 | + with open(input_file, 'r') as infile, open(output_file, 'w', newline='') as outfile: |
| 39 | + reader = csv.DictReader(infile) |
| 40 | + |
| 41 | + # Create new fieldnames with both code and label |
| 42 | + fieldnames = [] |
| 43 | + for field in reader.fieldnames: |
| 44 | + fieldnames.append(field) |
| 45 | + if field == 'CATEGORY': |
| 46 | + fieldnames.append('CATEGORY_LABEL') |
| 47 | + |
| 48 | + writer = csv.DictWriter(outfile, fieldnames=fieldnames, dialect='unix') |
| 49 | + writer.writeheader() |
| 50 | + |
| 51 | + for row in reader: |
| 52 | + if 'CATEGORY' in row: |
| 53 | + category_code = row['CATEGORY'] |
| 54 | + # Convert code to label, fallback to uppercase first part if not found |
| 55 | + category_label = category_mapping.get( |
| 56 | + category_code, |
| 57 | + category_code.split('.')[0].upper() if category_code and '.' in category_code else category_code |
| 58 | + ) |
| 59 | + row['CATEGORY_LABEL'] = category_label |
| 60 | + |
| 61 | + writer.writerow(row) |
0 commit comments