From 27917b63a14c777b6993c7f8745c7a679587b7cc Mon Sep 17 00:00:00 2001 From: Andreas Date: Thu, 26 Oct 2023 21:57:13 +0200 Subject: [PATCH] Add basic sqlite3 saving & loading support --- d2warehouse/db.py | 38 +++++++ d2warehouse/item.py | 194 ++++++++++++++++++++++++++++++++++- d2warehouse/schema.sql | 70 +++++++++++++ d2warehouse/tests/test_db.py | 31 ++++++ 4 files changed, 329 insertions(+), 4 deletions(-) create mode 100644 d2warehouse/db.py create mode 100644 d2warehouse/schema.sql create mode 100644 d2warehouse/tests/test_db.py diff --git a/d2warehouse/db.py b/d2warehouse/db.py new file mode 100644 index 0000000..9708622 --- /dev/null +++ b/d2warehouse/db.py @@ -0,0 +1,38 @@ +import atexit +import os +import sqlite3 + + +_schema_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "schema.sql") +_db = None +_path = "items.sqlite3" + + +def set_db_path(f: str) -> None: + global _path + _path = f + + +def get_db() -> sqlite3.Connection: + global _db + if not _db: + _db = sqlite3.connect( + _path, + detect_types=sqlite3.PARSE_DECLTYPES, + ) + _db.row_factory = sqlite3.Row + return _db + + +@atexit.register +def close_db() -> None: + global _db + if _db: + _db.close() + _db = None + + +def create_db() -> None: + db = get_db() + with open(_schema_path, encoding="utf-8") as f: + db.executescript(f.read()) diff --git a/d2warehouse/item.py b/d2warehouse/item.py index e1908f3..a4c8499 100644 --- a/d2warehouse/item.py +++ b/d2warehouse/item.py @@ -21,7 +21,9 @@ from typing import Optional from bitarray import bitarray from bitarray.util import int2ba from dataclasses import dataclass -from enum import Enum +from enum import IntEnum + +from d2warehouse.db import get_db _data_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "data") _basetype_map = None @@ -31,7 +33,7 @@ _set_item_map = None _runeword_map = None -class Quality(Enum): +class Quality(IntEnum): LOW = 1 NORMAL = 2 HIGH = 3 @@ -45,7 +47,7 @@ class Quality(Enum): return self.name.capitalize() -class LowQualityType(Enum): +class LowQualityType(IntEnum): CRUDE = 0 CRACKED = 1 DAMAGED = 2 @@ -60,7 +62,7 @@ class Stat: id: int | None = None # TODO: These 3 should probably not be optional values: list[int] | None = None parameter: int | None = None - text: str | None = None + text: str | None = None # TODO: Make this a property def print(self, indent=5): print(" " * indent, str(self)) @@ -201,6 +203,190 @@ class Item: base = lookup_basetype(self.code) return base["width"], base["height"] + def write_to_db(self, socketed_into=None, commit=True) -> int: + name = lookup_basetype(self.code)["name"] + # FIXME: handle magic & rare names + if self.is_runeword: + name = lookup_runeword(self.runeword_id)["name"] + elif self.quality == Quality.SET: + name = lookup_set_item(self.set_id)["name"] + elif self.quality == Quality.UNIQUE: + name = lookup_unique(self.unique_id)["name"] + + set_name = ( + lookup_set_item(self.set_id)["set"] if self.quality == Quality.SET else None + ) + prefix1 = self.prefixes[0] if self.prefixes and len(self.prefixes) > 0 else None + prefix2 = self.prefixes[1] if self.prefixes and len(self.prefixes) > 1 else None + prefix3 = self.prefixes[2] if self.prefixes and len(self.prefixes) > 2 else None + suffix1 = self.suffixes[0] if self.suffixes and len(self.suffixes) > 0 else None + suffix2 = self.suffixes[1] if self.suffixes and len(self.suffixes) > 1 else None + suffix3 = self.suffixes[2] if self.suffixes and len(self.suffixes) > 2 else None + + db = get_db() + cur = db.cursor() + cur.execute( + """INSERT INTO item (itembase_name, item_name, + set_name, socketed_into, raw_data, is_identified, is_socketed, + is_beginner, is_simple, is_ethereal, is_personalized, is_runeword, pos_x, + pos_y, code, uid, lvl, quality, graphic, implicit, low_quality, prefix1, + prefix2, prefix3, suffix1, suffix2, suffix3, set_id, unique_id, + nameword1, nameword2, runeword_id, personal_name, defense, durability, + max_durability, quantity) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, + ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""", + ( + lookup_basetype(self.code)["name"], + name, + set_name, + socketed_into, + self.raw_data, + self.is_identified, + self.is_socketed, + self.is_beginner, + self.is_simple, + self.is_ethereal, + self.is_personalized, + self.is_runeword, + self.pos_x, + self.pos_y, + self.code, + self.uid, + self.lvl, + int(self.quality) if self.quality else None, + self.graphic, + self.implicit, + int(self.low_quality) if self.low_quality else None, + prefix1, + prefix2, + prefix3, + suffix1, + suffix2, + suffix3, + self.set_id, + self.unique_id, + self.nameword1, + self.nameword2, + self.runeword_id, + self.personal_name, + self.defense, + self.durability, + self.max_durability, + self.quantity, + ), + ) + item_id = cur.lastrowid + + if self.stats: + for stat in self.stats: + db.execute( + """INSERT INTO item_stat (item_id, stat, value1, value2, value3, parameter) + VALUES (?, ?, ?, ?, ?, ?) + """, + ( + item_id, + stat.id, + stat.values[0] if len(stat.values) > 0 else None, + stat.values[1] if len(stat.values) > 1 else None, + stat.values[2] if len(stat.values) > 2 else None, + stat.parameter, + ), + ) + + if self.sockets: + for socket in self.sockets: + socket.write_to_db(socketed_into=item_id, commit=False) + + if commit: + db.commit() + + return item_id + + def load_from_db(id: int) -> "Item": + db = get_db() + row = db.execute( + """SELECT raw_data, is_identified, is_socketed, + is_beginner, is_simple, is_ethereal, is_personalized, is_runeword, pos_x, + pos_y, code, uid, lvl, quality, graphic, implicit, low_quality, prefix1, + prefix2, prefix3, suffix1, suffix2, suffix3, set_id, unique_id, + nameword1, nameword2, runeword_id, personal_name, defense, durability, + max_durability, quantity + FROM item WHERE id = ?""", + (id,), + ).fetchone() + + quality = Quality(row["quality"]) if row["quality"] else None + prefixes = None + suffixes = None + if quality in [Quality.MAGIC, Quality.RARE, Quality.CRAFTED]: + prefixes = [ + x for x in [row["prefix1"], row["prefix2"], row["prefix3"]] if x + ] + suffixes = [ + x for x in [row["suffix1"], row["suffix2"], row["suffix3"]] if x + ] + + item = Item( + raw_data=row["raw_data"], + is_identified=bool(row["is_identified"]), + is_socketed=bool(row["is_socketed"]), + is_beginner=bool(row["is_beginner"]), + is_simple=bool(row["is_simple"]), + is_ethereal=bool(row["is_ethereal"]), + is_personalized=bool(row["is_personalized"]), + is_runeword=bool(row["is_runeword"]), + pos_x=row["pos_x"], + pos_y=row["pos_y"], + code=row["code"], + uid=row["uid"], + lvl=row["lvl"], + quality=quality, + graphic=row["graphic"], + implicit=row["implicit"], + low_quality=LowQualityType(row["low_quality"]) + if row["low_quality"] + else None, + prefixes=prefixes, + suffixes=suffixes, + set_id=row["set_id"], + unique_id=row["unique_id"], + nameword1=row["nameword1"], + nameword2=row["nameword2"], + runeword_id=row["runeword_id"], + personal_name=row["personal_name"], + defense=row["defense"], + durability=row["durability"], + max_durability=row["max_durability"], + quantity=row["quantity"], + ) + + rows = db.execute( + "SELECT id FROM item WHERE socketed_into = ?", (id,) + ).fetchall() + if len(rows) > 0: + item.sockets = [] + for row in rows: + socket = Item.load_from_db(row["id"]) + item.sockets.append(socket) + + rows = db.execute( + "SELECT stat, value1, value2, value3, parameter FROM item_stat WHERE item_id = ?", + (id,), + ).fetchall() + if len(rows) > 0: + item.stats = [] + for row in rows: + values = [] + for i in range(1, 4): + if row[f"value{i}"] is not None: + values.append(row[f"value{i}"]) + stat = Stat(id=row["stat"], values=values, parameter=row["parameter"]) + stat_data = lookup_stat(stat.id) + stat.text = stat_data["text"] + item.stats.append(stat) + + return item + def lookup_basetype(code: str) -> dict: global _basetype_map diff --git a/d2warehouse/schema.sql b/d2warehouse/schema.sql new file mode 100644 index 0000000..a24f34b --- /dev/null +++ b/d2warehouse/schema.sql @@ -0,0 +1,70 @@ +DROP TABLE IF EXISTS stat; +DROP TABLE IF EXISTS item; + +CREATE TABLE item ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + created TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + deletion TEXT DEFAULT NULL CHECK (deletion IN (NULL, "deleted", "nuked")), -- * + itembase_name TEXT NOT NULL, + item_name TEXT NOT NULL, + set_name TEXT DEFAULT NULL, + socketed_into INTEGER DEFAULT NULL, + + -- The following fields match the fields of the item object in item.py + raw_data BLOB NOT NULL, + is_identified BOOLEAN NOT NULL, + is_socketed BOOLEAN NOT NULL, + is_beginner BOOLEAN NOT NULL, + is_simple BOOLEAN NOT NULL, + is_ethereal BOOLEAN NOT NULL, + is_personalized BOOLEAN NOT NULL, + is_runeword BOOLEAN NOT NULL, + pos_x INTEGER NOT NULL, + pos_y INTEGER NOT NULL, + code TEXT NOT NULL, + uid INTEGER DEFAULT NULL, + lvl INTEGER DEFAULT NULL, + quality INTEGER DEFAULT NULL CHECK(1 <= quality AND quality <= 8), + graphic INTEGER DEFAULT NULL, + implicit INTEGER DEFAULT NULL, + low_quality INTEGER DEFAULT NULL CHECK(0 <= low_quality AND low_quality <= 3), + prefix1 INTEGER DEFAULT NULL, + prefix2 INTEGER DEFAULT NULL, + prefix3 INTEGER DEFAULT NULL, + suffix1 INTEGER DEFAULT NULL, + suffix2 INTEGER DEFAULT NULL, + suffix3 INTEGER DEFAULT NULL, + set_id INTEGER DEFAULT NULL, + unique_id INTEGER DEFAULT NULL, + nameword1 INTEGER DEFAULT NULL, + nameword2 INTEGER DEFAULT NULL, + runeword_id INTEGER DEFAULT NULL, + personal_name TEXT DEFAULT NULL, + defense INTEGER DEFAULT NULL, + durability INTEGER DEFAULT NULL, + max_durability INTEGER DEFAULT NULL, + -- sockets: list[Optional["Item"]] | None = None => see socketed_into + quantity INTEGER DEFAULT NULL, + -- stats: list[Stat] | None = None => see table 'item_stat' + + FOREIGN KEY (socketed_into) REFERENCES item (item_id) +); +-- Add an index for "... WHERE deletion IS NULL" +CREATE INDEX item_deletion_partial ON item (deletion) WHERE deletion IS NULL; +-- *: NULL: if the item is currently stored, +-- deleted: if the item has been removed from storage +-- nuked: if the item has been removed from storage & user indicated he does not +-- want to count it as potentially in his possession any longer + +CREATE TABLE item_stat ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + item_id INTEGER NOT NULL, + stat INTEGER NOT NULL, + value1 INTEGER DEFAULT NULL, + value2 INTEGER DEFAULT NULL, + value3 INTEGER DEFAULT NULL, + parameter INTEGER DEFAULT NULL, + + FOREIGN KEY (item_id) REFERENCES item (item_id) +); +CREATE INDEX item_stat_stat ON item_stat (stat); diff --git a/d2warehouse/tests/test_db.py b/d2warehouse/tests/test_db.py new file mode 100644 index 0000000..5b2060e --- /dev/null +++ b/d2warehouse/tests/test_db.py @@ -0,0 +1,31 @@ +import os +import tempfile +import unittest + +from d2warehouse.db import close_db, create_db, set_db_path +from d2warehouse.item import Item +from d2warehouse.parser import parse_item + + +class DbTest(unittest.TestCase): + @classmethod + def setUpClass(cls) -> None: + cls._fd, cls._path = tempfile.mkstemp() + set_db_path(cls._path) + create_db() + + @classmethod + def tearDownClass(cls): + close_db() + os.close(cls._fd) + os.unlink(cls._path) + + def test_runeword_lore(self): + data = bytes.fromhex( + "1008800405c055c637f1073af4558697412981070881506049e87f005516fb134582ff1000a0003500e07cbb001000a0003504e07c9800" + ) + _, item = parse_item(data) + + db_id = item.write_to_db() + loaded_item = Item.load_from_db(db_id) + self.assertEqual(item, loaded_item)