"""Tests for stocks.py - unified stock management.""" import json import sys from datetime import datetime from pathlib import Path from unittest.mock import patch import pytest # Add scripts to path sys.path.insert(0, str(Path(__file__).parent.parent / "scripts")) from stocks import ( load_stocks, save_stocks, get_holdings, get_watchlist, get_holding_tickers, get_watchlist_tickers, add_to_watchlist, add_to_holdings, move_to_holdings, remove_stock, ) @pytest.fixture def sample_stocks_data(): """Sample stocks data for testing.""" return { "version": "1.0", "updated": "2026-01-30", "holdings": [ {"ticker": "AAPL", "name": "Apple Inc.", "category": "Tech"}, {"ticker": "MSFT", "name": "Microsoft", "category": "Tech"}, ], "watchlist": [ {"ticker": "NVDA", "target": 800.0, "notes": "Buy on dip"}, {"ticker": "TSLA", "target": 200.0, "notes": "Watch earnings"}, ], "alert_definitions": {}, } @pytest.fixture def stocks_file(tmp_path, sample_stocks_data): """Create a temporary stocks file.""" stocks_path = tmp_path / "stocks.json" stocks_path.write_text(json.dumps(sample_stocks_data)) return stocks_path class TestLoadStocks: """Tests for load_stocks().""" def test_load_existing_file(self, stocks_file, sample_stocks_data): """Load stocks from existing file.""" data = load_stocks(stocks_file) assert data["version"] == "1.0" assert len(data["holdings"]) == 2 assert len(data["watchlist"]) == 2 def test_load_missing_file(self, tmp_path): """Return default structure when file doesn't exist.""" missing_path = tmp_path / "missing.json" data = load_stocks(missing_path) assert data["version"] == "1.0" assert data["holdings"] == [] assert data["watchlist"] == [] assert "alert_definitions" in data class TestSaveStocks: """Tests for save_stocks().""" def test_save_updates_timestamp(self, tmp_path, sample_stocks_data): """Save should update the 'updated' field.""" stocks_path = tmp_path / "stocks.json" save_stocks(sample_stocks_data, stocks_path) saved = json.loads(stocks_path.read_text()) assert saved["updated"] == datetime.now().strftime("%Y-%m-%d") def test_save_preserves_data(self, tmp_path, sample_stocks_data): """Save should preserve all data.""" stocks_path = tmp_path / "stocks.json" save_stocks(sample_stocks_data, stocks_path) saved = json.loads(stocks_path.read_text()) assert len(saved["holdings"]) == 2 assert saved["holdings"][0]["ticker"] == "AAPL" class TestGetHoldings: """Tests for get_holdings().""" def test_get_holdings_from_data(self, sample_stocks_data): """Get holdings from provided data.""" holdings = get_holdings(sample_stocks_data) assert len(holdings) == 2 assert holdings[0]["ticker"] == "AAPL" def test_get_holdings_empty(self): """Return empty list for empty data.""" data = {"holdings": [], "watchlist": []} holdings = get_holdings(data) assert holdings == [] class TestGetWatchlist: """Tests for get_watchlist().""" def test_get_watchlist_from_data(self, sample_stocks_data): """Get watchlist from provided data.""" watchlist = get_watchlist(sample_stocks_data) assert len(watchlist) == 2 assert watchlist[0]["ticker"] == "NVDA" def test_get_watchlist_empty(self): """Return empty list for empty data.""" data = {"holdings": [], "watchlist": []} watchlist = get_watchlist(data) assert watchlist == [] class TestGetHoldingTickers: """Tests for get_holding_tickers().""" def test_get_holding_tickers(self, sample_stocks_data): """Get set of holding tickers.""" tickers = get_holding_tickers(sample_stocks_data) assert tickers == {"AAPL", "MSFT"} def test_get_holding_tickers_empty(self): """Return empty set for no holdings.""" data = {"holdings": [], "watchlist": []} tickers = get_holding_tickers(data) assert tickers == set() class TestGetWatchlistTickers: """Tests for get_watchlist_tickers().""" def test_get_watchlist_tickers(self, sample_stocks_data): """Get set of watchlist tickers.""" tickers = get_watchlist_tickers(sample_stocks_data) assert tickers == {"NVDA", "TSLA"} def test_get_watchlist_tickers_empty(self): """Return empty set for empty watchlist.""" data = {"holdings": [], "watchlist": []} tickers = get_watchlist_tickers(data) assert tickers == set() class TestAddToWatchlist: """Tests for add_to_watchlist().""" def test_add_new_to_watchlist(self, stocks_file, monkeypatch): """Add new stock to watchlist.""" monkeypatch.setattr("stocks.STOCKS_FILE", stocks_file) result = add_to_watchlist("AMD", target=150.0, notes="Watch for dip") assert result is True data = load_stocks(stocks_file) tickers = get_watchlist_tickers(data) assert "AMD" in tickers def test_update_existing_watchlist(self, stocks_file, monkeypatch): """Update existing watchlist entry.""" monkeypatch.setattr("stocks.STOCKS_FILE", stocks_file) # NVDA already in watchlist with target 800 result = add_to_watchlist("NVDA", target=750.0, notes="Updated target") assert result is True data = load_stocks(stocks_file) nvda = next(w for w in data["watchlist"] if w["ticker"] == "NVDA") assert nvda["target"] == 750.0 assert nvda["notes"] == "Updated target" def test_add_with_alerts(self, stocks_file, monkeypatch): """Add stock with alert definitions.""" monkeypatch.setattr("stocks.STOCKS_FILE", stocks_file) alerts = ["below_target", "above_stop"] result = add_to_watchlist("GOOG", target=180.0, alerts=alerts) assert result is True data = load_stocks(stocks_file) goog = next(w for w in data["watchlist"] if w["ticker"] == "GOOG") assert goog["alerts"] == alerts class TestAddToHoldings: """Tests for add_to_holdings().""" def test_add_new_holding(self, stocks_file, monkeypatch): """Add new stock to holdings.""" monkeypatch.setattr("stocks.STOCKS_FILE", stocks_file) result = add_to_holdings("GOOG", name="Alphabet", category="Tech") assert result is True data = load_stocks(stocks_file) tickers = get_holding_tickers(data) assert "GOOG" in tickers def test_update_existing_holding(self, stocks_file, monkeypatch): """Update existing holding.""" monkeypatch.setattr("stocks.STOCKS_FILE", stocks_file) result = add_to_holdings("AAPL", name="Apple Inc.", category="Consumer", notes="Core holding") assert result is True data = load_stocks(stocks_file) aapl = next(h for h in data["holdings"] if h["ticker"] == "AAPL") assert aapl["category"] == "Consumer" assert aapl["notes"] == "Core holding" class TestMoveToHoldings: """Tests for move_to_holdings().""" def test_move_from_watchlist(self, stocks_file, monkeypatch): """Move stock from watchlist to holdings.""" monkeypatch.setattr("stocks.STOCKS_FILE", stocks_file) # NVDA is in watchlist, not holdings result = move_to_holdings("NVDA", name="NVIDIA Corp", category="Semis") assert result is True data = load_stocks(stocks_file) assert "NVDA" in get_holding_tickers(data) assert "NVDA" not in get_watchlist_tickers(data) def test_move_nonexistent_returns_false(self, stocks_file, monkeypatch): """Moving non-existent ticker returns False.""" monkeypatch.setattr("stocks.STOCKS_FILE", stocks_file) result = move_to_holdings("NONEXISTENT") assert result is False class TestRemoveStock: """Tests for remove_stock().""" def test_remove_from_holdings(self, stocks_file, monkeypatch): """Remove stock from holdings.""" monkeypatch.setattr("stocks.STOCKS_FILE", stocks_file) result = remove_stock("AAPL", from_list="holdings") assert result is True data = load_stocks(stocks_file) assert "AAPL" not in get_holding_tickers(data) def test_remove_from_watchlist(self, stocks_file, monkeypatch): """Remove stock from watchlist.""" monkeypatch.setattr("stocks.STOCKS_FILE", stocks_file) result = remove_stock("NVDA", from_list="watchlist") assert result is True data = load_stocks(stocks_file) assert "NVDA" not in get_watchlist_tickers(data) def test_remove_nonexistent_returns_false(self, stocks_file, monkeypatch): """Removing non-existent ticker returns False.""" monkeypatch.setattr("stocks.STOCKS_FILE", stocks_file) result = remove_stock("NONEXISTENT", from_list="holdings") assert result is False def test_remove_auto_detects_list(self, stocks_file, monkeypatch): """Remove without specifying list auto-detects.""" monkeypatch.setattr("stocks.STOCKS_FILE", stocks_file) # AAPL is in holdings result = remove_stock("AAPL") assert result is True data = load_stocks(stocks_file) assert "AAPL" not in get_holding_tickers(data)