feat: implement theme stripping system with THEME_MIN_CARDS config

This commit is contained in:
matt 2026-03-19 15:05:40 -07:00
parent 1ebc2fcb3c
commit 86ece36012
20 changed files with 6604 additions and 1364 deletions

View file

@ -6,6 +6,7 @@ from collections import Counter
from typing import Dict, List, Set, Any
import pandas as pd
import numpy as np
import itertools
import math
try:
@ -20,6 +21,7 @@ if ROOT not in sys.path:
from code.settings import CSV_DIRECTORY
from code.tagging import tag_constants
from code.path_util import get_processed_cards_path
BASE_COLORS = {
'white': 'W',
@ -88,83 +90,113 @@ def collect_theme_tags_from_tagger_source() -> Set[str]:
def tally_tag_frequencies_by_base_color() -> Dict[str, Dict[str, int]]:
"""
Tally theme tag frequencies by base color from parquet files.
Note: This function now reads from card_files/processed/all_cards.parquet
instead of per-color CSV files. The CSV files no longer exist after the
parquet migration.
Returns:
Dictionary mapping color names to Counter of tag frequencies
"""
result: Dict[str, Dict[str, int]] = {c: Counter() for c in BASE_COLORS.keys()}
# Iterate over per-color CSVs; if not present, skip
for color in BASE_COLORS.keys():
path = os.path.join(CSV_DIRECTORY, f"{color}_cards.csv")
if not os.path.exists(path):
# Load from all_cards.parquet
parquet_path = get_processed_cards_path()
if not os.path.exists(parquet_path):
print(f"Warning: Parquet file not found: {parquet_path}")
return {k: dict(v) for k, v in result.items()}
try:
df = pd.read_parquet(parquet_path, columns=['themeTags', 'colorIdentity'], engine='pyarrow')
except Exception as e:
print(f"Error reading parquet file: {e}")
return {k: dict(v) for k, v in result.items()}
if 'themeTags' not in df.columns:
print("Warning: themeTags column not found in parquet file")
return {k: dict(v) for k, v in result.items()}
# Iterate rows and tally tags by base color
for _, row in df.iterrows():
# Parquet stores themeTags as numpy array
tags = row.get('themeTags')
if not isinstance(tags, (list, np.ndarray)):
continue
try:
df = pd.read_csv(path, converters={'themeTags': pd.eval, 'colorIdentity': pd.eval})
except Exception:
df = pd.read_csv(path)
if 'themeTags' in df.columns:
try:
df['themeTags'] = df['themeTags'].apply(pd.eval)
except Exception:
df['themeTags'] = df['themeTags'].apply(lambda x: [])
if 'colorIdentity' in df.columns:
try:
df['colorIdentity'] = df['colorIdentity'].apply(pd.eval)
except Exception:
pass
if 'themeTags' not in df.columns:
if isinstance(tags, np.ndarray):
tags = tags.tolist()
# Get color identity (stored as string like "W", "UB", "WUG", etc.)
ci = row.get('colorIdentity')
if isinstance(ci, np.ndarray):
ci = ci.tolist()
# Convert colorIdentity to set of letters
if isinstance(ci, str):
letters = set(ci) # "WUG" -> {'W', 'U', 'G'}
elif isinstance(ci, list):
letters = set(ci) # ['W', 'U', 'G'] -> {'W', 'U', 'G'}
else:
letters = set()
# Determine base colors from color identity
bases = {name for name, letter in BASE_COLORS.items() if letter in letters}
if not bases:
# Colorless cards don't contribute to any specific color
continue
# Derive base colors from colorIdentity if available, else assume single color file
def rows_base_colors(row):
ids = row.get('colorIdentity') if isinstance(row, dict) else row
if isinstance(ids, list):
letters = set(ids)
else:
letters = set()
derived = set()
for name, letter in BASE_COLORS.items():
if letter in letters:
derived.add(name)
if not derived:
derived.add(color)
return derived
# Iterate rows
for _, row in df.iterrows():
tags = list(row['themeTags']) if hasattr(row.get('themeTags'), '__len__') and not isinstance(row.get('themeTags'), str) else []
# Compute base colors contribution
ci = row['colorIdentity'] if 'colorIdentity' in row else None
letters = set(ci) if isinstance(ci, list) else set()
bases = {name for name, letter in BASE_COLORS.items() if letter in letters}
if not bases:
bases = {color}
for bc in bases:
for t in tags:
result[bc][t] += 1
# Tally tags for each base color this card belongs to
for base_color in bases:
for tag in tags:
if isinstance(tag, str) and tag:
result[base_color][tag] += 1
# Convert Counters to plain dicts
return {k: dict(v) for k, v in result.items()}
def gather_theme_tag_rows() -> List[List[str]]:
"""Collect per-card themeTags lists across all base color CSVs.
"""
Collect per-card themeTags lists from parquet file.
Note: This function now reads from card_files/processed/all_cards.parquet
instead of per-color CSV files. The CSV files no longer exist after the
parquet migration.
Returns a list of themeTags arrays, one per card row where themeTags is present.
Returns:
List of themeTags arrays, one per card row where themeTags is present.
"""
rows: List[List[str]] = []
for color in BASE_COLORS.keys():
path = os.path.join(CSV_DIRECTORY, f"{color}_cards.csv")
if not os.path.exists(path):
continue
try:
df = pd.read_csv(path, converters={'themeTags': pd.eval})
except Exception:
df = pd.read_csv(path)
if 'themeTags' in df.columns:
try:
df['themeTags'] = df['themeTags'].apply(pd.eval)
except Exception:
df['themeTags'] = df['themeTags'].apply(lambda x: [])
if 'themeTags' not in df.columns:
continue
for _, row in df.iterrows():
tags = list(row['themeTags']) if hasattr(row.get('themeTags'), '__len__') and not isinstance(row.get('themeTags'), str) else []
if tags:
rows.append(tags)
# Load from all_cards.parquet
parquet_path = get_processed_cards_path()
if not os.path.exists(parquet_path):
print(f"Warning: Parquet file not found: {parquet_path}")
return rows
try:
df = pd.read_parquet(parquet_path, columns=['themeTags'], engine='pyarrow')
except Exception as e:
print(f"Error reading parquet file: {e}")
return rows
if 'themeTags' not in df.columns:
print("Warning: themeTags column not found in parquet file")
return rows
# Collect theme tags from each card
for _, row in df.iterrows():
# Parquet stores themeTags as numpy array
tags = row.get('themeTags')
if isinstance(tags, np.ndarray):
tags = tags.tolist()
if isinstance(tags, list) and tags:
# Convert to list of strings (filter out non-strings)
tag_list = [str(t) for t in tags if isinstance(t, str) and t]
if tag_list:
rows.append(tag_list)
return rows