#!/usr/bin/env python3
"""OpenAlex literature search via curl — no external dependencies.

Usage:
    python search_openalex.py <query> [--domain D3] [--oa] [--from YEAR] [--sort relevance|cited|date] [--max N] [--abstract]
    python search_openalex.py "digital art philosophy" --domain D3 --max 10
    python search_openalex.py "smart grid energy" --domain D4 --from 2020 --sort cited

Domains: D3=Arts, D2=Philosophy, D1=SocialSci, D4=Engineering, D6=CS/AI, D5=Math
"""
import sys
import json
import subprocess
import shlex
import urllib.parse

BASE = "https://api.openalex.org"


def search(query, domain=None, year_from=None, oa_only=False, max_results=10, sort="relevance"):
    """Query the OpenAlex ``/works`` endpoint through ``curl`` and return the parsed JSON.

    Args:
        query: Free-text search string, sent as the ``search`` parameter.
        domain: Optional value for the ``primary_topic.domain.id`` filter.
        year_from: Optional year; adds a ``from_publication_date`` filter of
            January 1st of that year.
        oa_only: When true, adds the ``is_oa:true`` filter.
        max_results: Requested page size, clamped to the range 1..100.
        sort: One of ``"relevance"``, ``"cited"`` or ``"date"``. Any other
            value falls back to relevance ordering.

    Returns:
        The decoded JSON response body, as produced by ``json.loads``.

    Exits:
        Terminates the process with status 1 (after printing to stderr) when
        curl cannot be started, curl exits non-zero, or the body is not JSON.
    """
    # Keep the page size inside 1..100 so a zero or negative --max does not
    # produce an invalid request.
    page_size = max(1, min(max_results, 100))
    params = [
        ("search", query),
        ("per_page", page_size),
    ]

    filters = []
    if domain:
        filters.append(f"primary_topic.domain.id:{domain}")
    if year_from:
        filters.append(f"from_publication_date:{year_from}-01-01")
    if oa_only:
        filters.append("is_oa:true")
    if filters:
        # Multiple filters are sent as one comma-separated parameter.
        params.append(("filter", ",".join(filters)))

    sort_map = {
        "relevance": "relevance_score:desc",
        "cited": "cited_by_count:desc",
        "date": "publication_date:desc",
    }
    params.append(("sort", sort_map.get(sort, sort_map["relevance"])))

    query_str = "&".join(f"{k}={urllib.parse.quote(str(v))}" for k, v in params)
    url = f"{BASE}/works?{query_str}"

    # -s hides the progress meter; -S keeps curl's error text on stderr so the
    # failure message below actually has something to report.
    cmd = ["curl", "-s", "-S", "-L", "--max-time", "20",
           "-H", "User-Agent: HermesAgent/1.0 (mailto:kuhnn@example.com)",
           "-H", "Accept: application/json",
           url]
    try:
        # Decode explicitly as UTF-8 rather than relying on the locale encoding.
        result = subprocess.run(cmd, capture_output=True, text=True,
                                encoding="utf-8", errors="replace")
    except OSError as e:
        # Raised when the curl executable is missing or cannot be launched.
        print(f"Error: could not run curl — {e}", file=sys.stderr)
        sys.exit(1)

    if result.returncode != 0:
        print(f"Error: curl failed — {result.stderr.strip()}", file=sys.stderr)
        sys.exit(1)
    try:
        return json.loads(result.stdout)
    except json.JSONDecodeError as e:
        print(f"Error parsing response: {e}", file=sys.stderr)
        print(f"Response: {result.stdout[:500]}", file=sys.stderr)
        sys.exit(1)


def print_results(data, show_abstract=False):
    """Print a human-readable listing of a works response to stdout.

    Args:
        data: Decoded response dict; reads ``meta.count`` and ``results``.
        show_abstract: When true, also print up to 300 characters of each
            abstract, rebuilt from ``abstract_inverted_index``.

    Fields that are missing or JSON ``null`` fall back to placeholders
    ("Untitled", "?", 0, empty string) instead of printing ``None`` or raising.
    """
    results = data.get("results") or []
    count = (data.get("meta") or {}).get("count", 0)
    print(f"\nFound {count} papers (showing {len(results)})\n")
    for w in results:
        # `or` instead of a .get() default: the key may be present but null.
        title = w.get("title") or "Untitled"

        # Handle different authorship formats: a dict carrying display_name
        # directly, a dict with a nested author object, or a plain value.
        raw_authors = w.get("authorships") or []
        authors = []
        for a in raw_authors[:3]:
            if isinstance(a, dict):
                name = (a.get("display_name")
                        or (a.get("author") or {}).get("display_name")
                        or "Unknown")
                authors.append(name)
            else:
                authors.append(str(a))
        author_str = ", ".join(authors)
        if len(raw_authors) > 3:
            author_str += " et al."

        year = w.get("publication_year") or "?"
        cited = w.get("cited_by_count") or 0
        # Accept a top-level is_oa flag as well as the nested open_access.is_oa.
        is_oa = w.get("is_oa") or (w.get("open_access") or {}).get("is_oa")
        oa = " [Open Access]" if is_oa else ""
        doi = w.get("doi") or ""
        topic = w.get("primary_topic") or {}
        domain = (topic.get("domain") or {}).get("display_name") or ""

        print(f"[{year}] {title}")
        print(f"  Authors: {author_str}")
        print(f"  Cited: {cited}x  |  Field: {domain}{oa}")
        print(f"  {doi}")
        if show_abstract:
            abstract_inv = w.get("abstract_inverted_index")
            if abstract_inv:
                # The index maps word -> list of positions; sorting the
                # (position, word) pairs restores the original word order.
                positioned = sorted(
                    (pos, word)
                    for word, positions in abstract_inv.items()
                    for pos in positions
                )
                text = " ".join(word for _, word in positioned)
                print(f"  Abstract: {text[:300]}...")
        print()


if __name__ == "__main__":
    # Command-line entry point: the first argument is the query, the rest are
    # optional flags scanned by hand.
    if len(sys.argv) < 2 or sys.argv[1] in ("-h", "--help"):
        print(__doc__)
        sys.exit(0)

    q = sys.argv[1]
    domain, year_from, oa_only, max_n, sort = None, None, False, 10, "relevance"
    show_abstract = False

    i = 2
    while i < len(sys.argv):
        arg = sys.argv[i]
        has_value = i + 1 < len(sys.argv)
        if arg == "--oa":
            oa_only = True
            i += 1
        elif arg == "--abstract":
            show_abstract = True
            i += 1
        elif arg in ("--domain", "--from", "--max", "--sort") and has_value:
            value = sys.argv[i + 1]
            if arg == "--domain":
                domain = value
            elif arg == "--from":
                year_from = value
            elif arg == "--sort":
                sort = value
            else:
                # --max must be an integer; report a usage error rather than
                # letting int() raise a traceback.
                try:
                    max_n = int(value)
                except ValueError:
                    print(f"Error: --max expects an integer, got {value!r}",
                          file=sys.stderr)
                    sys.exit(2)
            i += 2
        else:
            # Unknown flags, and value flags given as the last argument, are
            # still skipped, but no longer silently.
            print(f"Warning: ignoring argument {arg!r}", file=sys.stderr)
            i += 1

    data = search(q, domain=domain, year_from=year_from, oa_only=oa_only,
                  max_results=max_n, sort=sort)
    print_results(data, show_abstract=show_abstract)
