From 40276cb0b41cffa6efd91ce082f4b6fcb96198d2 Mon Sep 17 00:00:00 2001 From: Andreas Ruden Date: Fri, 27 Oct 2023 13:48:59 +0000 Subject: [PATCH] Add basic sqlite3 saving & loading support --- d2warehouse/db.py | 38 +++++++ d2warehouse/fileformat.py | 1 + d2warehouse/item.py | 199 ++++++++++++++++++++++++++++++++++- d2warehouse/parser.py | 5 +- d2warehouse/schema.sql | 85 +++++++++++++++ d2warehouse/tests/test_db.py | 31 ++++++ pyproject.toml | 2 +- 7 files changed, 354 insertions(+), 7 deletions(-) create mode 100644 d2warehouse/db.py create mode 100644 d2warehouse/schema.sql create mode 100644 d2warehouse/tests/test_db.py diff --git a/d2warehouse/db.py b/d2warehouse/db.py new file mode 100644 index 0000000..9708622 --- /dev/null +++ b/d2warehouse/db.py @@ -0,0 +1,38 @@ +import atexit +import os +import sqlite3 + + +_schema_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "schema.sql") +_db = None +_path = "items.sqlite3" + + +def set_db_path(f: str) -> None: + global _path + _path = f + + +def get_db() -> sqlite3.Connection: + global _db + if not _db: + _db = sqlite3.connect( + _path, + detect_types=sqlite3.PARSE_DECLTYPES, + ) + _db.row_factory = sqlite3.Row + return _db + + +@atexit.register +def close_db() -> None: + global _db + if _db: + _db.close() + _db = None + + +def create_db() -> None: + db = get_db() + with open(_schema_path, encoding="utf-8") as f: + db.executescript(f.read()) diff --git a/d2warehouse/fileformat.py b/d2warehouse/fileformat.py index 23369bd..e3bfbaf 100644 --- a/d2warehouse/fileformat.py +++ b/d2warehouse/fileformat.py @@ -16,4 +16,5 @@ # Mercator. If not, see . STASH_TAB_MAGIC = b"\x55\xAA\x55\xAA" +STASH_TAB_VERSION = 99 ITEM_DATA_MAGIC = b"JM" diff --git a/d2warehouse/item.py b/d2warehouse/item.py index e1908f3..bc759b0 100644 --- a/d2warehouse/item.py +++ b/d2warehouse/item.py @@ -21,7 +21,9 @@ from typing import Optional from bitarray import bitarray from bitarray.util import int2ba from dataclasses import dataclass -from enum import Enum +from enum import IntEnum +from d2warehouse.fileformat import STASH_TAB_VERSION +from d2warehouse.db import get_db _data_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "data") _basetype_map = None @@ -31,7 +33,7 @@ _set_item_map = None _runeword_map = None -class Quality(Enum): +class Quality(IntEnum): LOW = 1 NORMAL = 2 HIGH = 3 @@ -45,7 +47,7 @@ class Quality(Enum): return self.name.capitalize() -class LowQualityType(Enum): +class LowQualityType(IntEnum): CRUDE = 0 CRACKED = 1 DAMAGED = 2 @@ -60,7 +62,7 @@ class Stat: id: int | None = None # TODO: These 3 should probably not be optional values: list[int] | None = None parameter: int | None = None - text: str | None = None + text: str | None = None # TODO: Make this a property def print(self, indent=5): print(" " * indent, str(self)) @@ -88,6 +90,7 @@ def txtbits(bits: bitarray) -> str: @dataclass class Item: raw_data: bytes + raw_version: int is_identified: bool is_socketed: bool is_beginner: bool @@ -201,6 +204,194 @@ class Item: base = lookup_basetype(self.code) return base["width"], base["height"] + def write_to_db(self, socketed_into=None, commit=True) -> int: + name = lookup_basetype(self.code)["name"] + # FIXME: handle magic & rare names + if self.is_runeword: + name = lookup_runeword(self.runeword_id)["name"] + elif self.quality == Quality.SET: + name = lookup_set_item(self.set_id)["name"] + elif self.quality == Quality.UNIQUE: + name = lookup_unique(self.unique_id)["name"] + + set_name = ( + lookup_set_item(self.set_id)["set"] if self.quality == Quality.SET else None + ) + + db = get_db() + cur = db.cursor() + cur.execute( + """INSERT INTO item (itembase_name, socketed_into, raw_data, raw_version, + is_identified, is_socketed, is_beginner, is_simple, is_ethereal, + is_personalized, is_runeword, code) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""", + ( + lookup_basetype(self.code)["name"], + socketed_into, + self.raw_data, + self.raw_version, + self.is_identified, + self.is_socketed, + self.is_beginner, + self.is_simple, + self.is_ethereal, + self.is_personalized, + self.is_runeword, + self.code, + ), + ) + item_id = cur.lastrowid + + cur.execute( + """INSERT INTO item_extra (item_id, item_name, + set_name, uid, lvl, quality, graphic, implicit, low_quality, set_id, + unique_id, nameword1, nameword2, runeword_id, personal_name, + defense, durability, max_durability, quantity) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""", + ( + item_id, + name, + set_name, + self.uid, + self.lvl, + int(self.quality) if self.quality else None, + self.graphic, + self.implicit, + int(self.low_quality) if self.low_quality else None, + self.set_id, + self.unique_id, + self.nameword1, + self.nameword2, + self.runeword_id, + self.personal_name, + self.defense, + self.durability, + self.max_durability, + self.quantity, + ), + ) + + if self.quality in [Quality.MAGIC, Quality.RARE, Quality.CRAFTED]: + for prefix, id in [(False, id) for id in self.suffixes] + [ + (True, id) for id in self.prefixes + ]: + db.execute( + "INSERT INTO item_affix (item_id, prefix, affix_id) VALUES (?, ?, ?)", + (item_id, prefix, id), + ) + + if self.stats: + for stat in self.stats: + db.execute( + """INSERT INTO item_stat (item_id, stat, value1, value2, value3, parameter) + VALUES (?, ?, ?, ?, ?, ?) + """, + ( + item_id, + stat.id, + stat.values[0] if len(stat.values) > 0 else None, + stat.values[1] if len(stat.values) > 1 else None, + stat.values[2] if len(stat.values) > 2 else None, + stat.parameter, + ), + ) + + if self.sockets: + for socket in self.sockets: + socket.write_to_db(socketed_into=item_id, commit=False) + + if commit: + db.commit() + + return item_id + + def load_from_db(id: int) -> "Item": + db = get_db() + row = db.execute( + """SELECT raw_data, raw_version, is_identified, is_socketed, + is_beginner, is_simple, is_ethereal, is_personalized, is_runeword, code, + uid, lvl, quality, graphic, implicit, low_quality, set_id, unique_id, + nameword1, nameword2, runeword_id, personal_name, defense, durability, + max_durability, quantity + FROM item INNER JOIN item_extra ON id = item_id WHERE id = ?""", + (id,), + ).fetchone() + if row["raw_version"] != STASH_TAB_VERSION: + raise RuntimeError("Can not load item, the raw version is not supported") + + item = Item( + raw_data=row["raw_data"], + raw_version=row["raw_version"], + is_identified=bool(row["is_identified"]), + is_socketed=bool(row["is_socketed"]), + is_beginner=bool(row["is_beginner"]), + is_simple=bool(row["is_simple"]), + is_ethereal=bool(row["is_ethereal"]), + is_personalized=bool(row["is_personalized"]), + is_runeword=bool(row["is_runeword"]), + pos_x=0, + pos_y=0, + code=row["code"], + uid=row["uid"], + lvl=row["lvl"], + quality=Quality(row["quality"]) if row["quality"] else None, + graphic=row["graphic"], + implicit=row["implicit"], + low_quality=LowQualityType(row["low_quality"]) + if row["low_quality"] + else None, + set_id=row["set_id"], + unique_id=row["unique_id"], + nameword1=row["nameword1"], + nameword2=row["nameword2"], + runeword_id=row["runeword_id"], + personal_name=row["personal_name"], + defense=row["defense"], + durability=row["durability"], + max_durability=row["max_durability"], + quantity=row["quantity"], + ) + + if item.quality in [Quality.MAGIC, Quality.RARE, Quality.CRAFTED]: + rows = db.execute( + "SELECT prefix, affix_id FROM item_affix WHERE item_id = ?", (id,) + ) + item.prefixes = [] + item.suffixes = [] + for row in rows: + if row["prefix"]: + item.prefixes.append(row["affix_id"]) + else: + item.suffixes.append(row["affix_id"]) + + rows = db.execute( + "SELECT id FROM item WHERE socketed_into = ?", (id,) + ).fetchall() + if len(rows) > 0: + item.sockets = [] + for row in rows: + socket = Item.load_from_db(row["id"]) + socket.pos_x = len(item.sockets) + item.sockets.append(socket) + + rows = db.execute( + "SELECT stat, value1, value2, value3, parameter FROM item_stat WHERE item_id = ?", + (id,), + ).fetchall() + if len(rows) > 0: + item.stats = [] + for row in rows: + values = [] + for i in range(1, 4): + if row[f"value{i}"] is not None: + values.append(row[f"value{i}"]) + stat = Stat(id=row["stat"], values=values, parameter=row["parameter"]) + stat_data = lookup_stat(stat.id) + stat.text = stat_data["text"] + item.stats.append(stat) + + return item + def lookup_basetype(code: str) -> dict: global _basetype_map diff --git a/d2warehouse/parser.py b/d2warehouse/parser.py index 50f95d5..73f972b 100644 --- a/d2warehouse/parser.py +++ b/d2warehouse/parser.py @@ -27,7 +27,7 @@ from d2warehouse.item import ( lookup_stat, ) import d2warehouse.huffman as huffman -from d2warehouse.fileformat import STASH_TAB_MAGIC, ITEM_DATA_MAGIC +from d2warehouse.fileformat import STASH_TAB_MAGIC, STASH_TAB_VERSION, ITEM_DATA_MAGIC class ParseError(RuntimeError): @@ -82,7 +82,7 @@ def parse_stash_tab(data: bytes) -> tuple[bytes, StashTab]: if unknown != 1: ParseError("Unknown stash tab field is not 1") - if version != 99: + if version != STASH_TAB_VERSION: ParseError(f"Unsupported stash tab version ({version} instead of 99)") tab = StashTab() @@ -133,6 +133,7 @@ def parse_item(data: bytes) -> tuple[bytes, Item]: simple_byte_sz = int((sockets_end + 7) / 8) item = Item( data[:simple_byte_sz], + STASH_TAB_VERSION, is_identified, is_socketed, is_beginner, diff --git a/d2warehouse/schema.sql b/d2warehouse/schema.sql new file mode 100644 index 0000000..7ba639a --- /dev/null +++ b/d2warehouse/schema.sql @@ -0,0 +1,85 @@ +DROP TABLE IF EXISTS item_stat; +DROP TABLE IF EXISTS item_affix; +DROP TABLE IF EXISTS item_extra; +DROP TABLE IF EXISTS item; + +CREATE TABLE item ( + id INTEGER PRIMARY KEY, + created TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + deleted TIMESTAMP DEFAULT NULL, + -- Nuked: if the item has been removed from storage & user indicated he does not + -- want to count it as potentially in his possession any longer + nuked BOOLEAN DEFAULT NULL, + -- Names: simple items have only a name equal to their base name, most non simple + -- items have two names: the base name and item_extra.item_name. + itembase_name TEXT NOT NULL, + socketed_into INTEGER DEFAULT NULL, + + -- The following fields match the fields of the item object in item.py + raw_data BLOB NOT NULL, + raw_version INTEGER NOT NULL, + is_identified BOOLEAN NOT NULL, + is_socketed BOOLEAN NOT NULL, + is_beginner BOOLEAN NOT NULL, + is_simple BOOLEAN NOT NULL, + is_ethereal BOOLEAN NOT NULL, + is_personalized BOOLEAN NOT NULL, + is_runeword BOOLEAN NOT NULL, + code TEXT NOT NULL, + + FOREIGN KEY (socketed_into) REFERENCES item (id) +); +-- Add an index for "... WHERE deletion IS NULL" +CREATE INDEX item_deletion_partial ON item (deleted) WHERE deleted IS NULL; +-- * nuked: if the item has been removed from storage & user indicated he does not +-- want to count it as potentially in his possession any longer + +CREATE TABLE item_extra ( + item_id INTEGER PRIMARY KEY, + item_name TEXT DEFAULT NULL, + set_name TEXT DEFAULT NULL, + + -- The following fields match the fields of the item object in item.py + uid INTEGER DEFAULT NULL, + lvl INTEGER DEFAULT NULL, + quality INTEGER DEFAULT NULL CHECK(1 <= quality AND quality <= 8), + graphic INTEGER DEFAULT NULL, + implicit INTEGER DEFAULT NULL, + low_quality INTEGER DEFAULT NULL CHECK(0 <= low_quality AND low_quality <= 3), + set_id INTEGER DEFAULT NULL, + unique_id INTEGER DEFAULT NULL, + nameword1 INTEGER DEFAULT NULL, + nameword2 INTEGER DEFAULT NULL, + runeword_id INTEGER DEFAULT NULL, + personal_name TEXT DEFAULT NULL, + defense INTEGER DEFAULT NULL, + durability INTEGER DEFAULT NULL, + max_durability INTEGER DEFAULT NULL, + -- sockets: list[Optional["Item"]] | None = None => see item.socketed_into + quantity INTEGER DEFAULT NULL, + -- stats: list[Stat] | None = None => see table 'item_stat' + + FOREIGN KEY (item_id) REFERENCES item (id) +); + +CREATE TABLE item_stat ( + id INTEGER PRIMARY KEY, + item_id INTEGER NOT NULL, + stat INTEGER NOT NULL, + value1 INTEGER DEFAULT NULL, + value2 INTEGER DEFAULT NULL, + value3 INTEGER DEFAULT NULL, + parameter INTEGER DEFAULT NULL, + + FOREIGN KEY (item_id) REFERENCES item (id) +); +CREATE INDEX item_stat_stat ON item_stat (stat); + +CREATE TABLE item_affix ( + id INTEGER PRIMARY KEY, + item_id INTEGER NOT NULL, + prefix BOOLEAN NOT NULL, + affix_id INTEGER NOT NULL, + + FOREIGN KEY (item_id) REFERENCES item (id) +); diff --git a/d2warehouse/tests/test_db.py b/d2warehouse/tests/test_db.py new file mode 100644 index 0000000..3081ebf --- /dev/null +++ b/d2warehouse/tests/test_db.py @@ -0,0 +1,31 @@ +import os +import tempfile +import unittest + +from d2warehouse.db import close_db, create_db, set_db_path +from d2warehouse.item import Item +from d2warehouse.parser import parse_item + + +class DbTest(unittest.TestCase): + @classmethod + def setUpClass(cls) -> None: + cls._fd, cls._path = tempfile.mkstemp() + set_db_path(cls._path) + create_db() + + @classmethod + def tearDownClass(cls): + close_db() + os.close(cls._fd) + os.unlink(cls._path) + + def test_runeword_lore(self): + data = bytes.fromhex( + "10088004050054c637f1073af4558697412981070881506049e87f005516fb134582ff1000a0003500e07cbb001000a0003504e07c9800" + ) + _, item = parse_item(data) + + db_id = item.write_to_db() + loaded_item = Item.load_from_db(db_id) + self.assertEqual(item, loaded_item) diff --git a/pyproject.toml b/pyproject.toml index 3f02014..87a12ad 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,7 +44,7 @@ d2test = "d2warehouse.test:main" version = {attr = "d2warehouse.__version__"} [tool.setuptools.package-data] -d2warehouse = ["data/*.json"] +d2warehouse = ["data/*.json", "schema.sql"] [tool.pytest.ini_options] addopts = "--cov --cov-report html --cov-report term"