Add basic sqlite3 saving & loading support

2023-10-27 13:48:59 +00:00, committed by Omicron
parent 4da8e096fe
commit 40276cb0b4
7 changed files with 354 additions and 7 deletions

d2warehouse/db.py (new file, 38 lines)

@@ -0,0 +1,38 @@
import atexit
import os
import sqlite3

_schema_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "schema.sql")

_db = None
_path = "items.sqlite3"


def set_db_path(f: str) -> None:
    global _path
    _path = f


def get_db() -> sqlite3.Connection:
    global _db
    if _db is None:
        _db = sqlite3.connect(
            _path,
            detect_types=sqlite3.PARSE_DECLTYPES,
        )
        _db.row_factory = sqlite3.Row
    return _db


@atexit.register
def close_db() -> None:
    global _db
    if _db is not None:
        _db.close()
        _db = None


def create_db() -> None:
    db = get_db()
    with open(_schema_path, encoding="utf-8") as f:
        db.executescript(f.read())
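
A minimal usage sketch of how this module is meant to be driven (not part of the diff above; the database path below is purely illustrative):

from d2warehouse.db import close_db, create_db, get_db, set_db_path

# Point the module at a database file before the first get_db() call;
# otherwise it falls back to "items.sqlite3" in the current directory.
set_db_path("/tmp/d2warehouse-demo.sqlite3")  # illustrative path

# create_db() runs schema.sql against the lazily opened connection.
create_db()
db = get_db()
print(db.execute("SELECT COUNT(*) FROM item").fetchone()[0])

# close_db() is registered with atexit, but it can also be called explicitly.
close_db()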

@@ -16,4 +16,5 @@
 # Mercator. If not, see <https://www.gnu.org/licenses/>.
 STASH_TAB_MAGIC = b"\x55\xAA\x55\xAA"
+STASH_TAB_VERSION = 99
 ITEM_DATA_MAGIC = b"JM"

@@ -21,7 +21,9 @@ from typing import Optional
 from bitarray import bitarray
 from bitarray.util import int2ba
 from dataclasses import dataclass
-from enum import Enum
+from enum import IntEnum
+from d2warehouse.fileformat import STASH_TAB_VERSION
+from d2warehouse.db import get_db
 
 _data_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "data")
 _basetype_map = None
@@ -31,7 +33,7 @@ _set_item_map = None
 _runeword_map = None
 
 
-class Quality(Enum):
+class Quality(IntEnum):
     LOW = 1
     NORMAL = 2
     HIGH = 3
@@ -45,7 +47,7 @@ class Quality(Enum):
         return self.name.capitalize()
 
 
-class LowQualityType(Enum):
+class LowQualityType(IntEnum):
     CRUDE = 0
     CRACKED = 1
     DAMAGED = 2
@@ -60,7 +62,7 @@ class Stat:
     id: int | None = None  # TODO: These 3 should probably not be optional
     values: list[int] | None = None
     parameter: int | None = None
-    text: str | None = None
+    text: str | None = None  # TODO: Make this a property
 
     def print(self, indent=5):
         print(" " * indent, str(self))
@@ -88,6 +90,7 @@ def txtbits(bits: bitarray) -> str:
 @dataclass
 class Item:
     raw_data: bytes
+    raw_version: int
     is_identified: bool
     is_socketed: bool
     is_beginner: bool
@@ -201,6 +204,194 @@ class Item:
        base = lookup_basetype(self.code)
        return base["width"], base["height"]

    def write_to_db(self, socketed_into=None, commit=True) -> int:
        name = lookup_basetype(self.code)["name"]
        # FIXME: handle magic & rare names
        if self.is_runeword:
            name = lookup_runeword(self.runeword_id)["name"]
        elif self.quality == Quality.SET:
            name = lookup_set_item(self.set_id)["name"]
        elif self.quality == Quality.UNIQUE:
            name = lookup_unique(self.unique_id)["name"]
        set_name = (
            lookup_set_item(self.set_id)["set"] if self.quality == Quality.SET else None
        )

        db = get_db()
        cur = db.cursor()
        cur.execute(
            """INSERT INTO item (itembase_name, socketed_into, raw_data, raw_version,
                is_identified, is_socketed, is_beginner, is_simple, is_ethereal,
                is_personalized, is_runeword, code)
                VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""",
            (
                lookup_basetype(self.code)["name"],
                socketed_into,
                self.raw_data,
                self.raw_version,
                self.is_identified,
                self.is_socketed,
                self.is_beginner,
                self.is_simple,
                self.is_ethereal,
                self.is_personalized,
                self.is_runeword,
                self.code,
            ),
        )
        item_id = cur.lastrowid
        cur.execute(
            """INSERT INTO item_extra (item_id, item_name,
                set_name, uid, lvl, quality, graphic, implicit, low_quality, set_id,
                unique_id, nameword1, nameword2, runeword_id, personal_name,
                defense, durability, max_durability, quantity)
                VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""",
            (
                item_id,
                name,
                set_name,
                self.uid,
                self.lvl,
                int(self.quality) if self.quality is not None else None,
                self.graphic,
                self.implicit,
                int(self.low_quality) if self.low_quality is not None else None,
                self.set_id,
                self.unique_id,
                self.nameword1,
                self.nameword2,
                self.runeword_id,
                self.personal_name,
                self.defense,
                self.durability,
                self.max_durability,
                self.quantity,
            ),
        )
        if self.quality in [Quality.MAGIC, Quality.RARE, Quality.CRAFTED]:
            for prefix, affix_id in [(False, a) for a in self.suffixes] + [
                (True, a) for a in self.prefixes
            ]:
                db.execute(
                    "INSERT INTO item_affix (item_id, prefix, affix_id) VALUES (?, ?, ?)",
                    (item_id, prefix, affix_id),
                )
        if self.stats:
            for stat in self.stats:
                db.execute(
                    """INSERT INTO item_stat (item_id, stat, value1, value2, value3, parameter)
                    VALUES (?, ?, ?, ?, ?, ?)""",
                    (
                        item_id,
                        stat.id,
                        stat.values[0] if len(stat.values) > 0 else None,
                        stat.values[1] if len(stat.values) > 1 else None,
                        stat.values[2] if len(stat.values) > 2 else None,
                        stat.parameter,
                    ),
                )
        if self.sockets:
            # Socketed children reference this row and share the single commit below.
            for socket in self.sockets:
                socket.write_to_db(socketed_into=item_id, commit=False)
        if commit:
            db.commit()
        return item_id

    @staticmethod
    def load_from_db(id: int) -> "Item":
        db = get_db()
        row = db.execute(
            """SELECT raw_data, raw_version, is_identified, is_socketed,
            is_beginner, is_simple, is_ethereal, is_personalized, is_runeword, code,
            uid, lvl, quality, graphic, implicit, low_quality, set_id, unique_id,
            nameword1, nameword2, runeword_id, personal_name, defense, durability,
            max_durability, quantity
            FROM item INNER JOIN item_extra ON id = item_id WHERE id = ?""",
            (id,),
        ).fetchone()
        if row["raw_version"] != STASH_TAB_VERSION:
            raise RuntimeError("Cannot load item, the raw version is not supported")
        item = Item(
            raw_data=row["raw_data"],
            raw_version=row["raw_version"],
            is_identified=bool(row["is_identified"]),
            is_socketed=bool(row["is_socketed"]),
            is_beginner=bool(row["is_beginner"]),
            is_simple=bool(row["is_simple"]),
            is_ethereal=bool(row["is_ethereal"]),
            is_personalized=bool(row["is_personalized"]),
            is_runeword=bool(row["is_runeword"]),
            pos_x=0,
            pos_y=0,
            code=row["code"],
            uid=row["uid"],
            lvl=row["lvl"],
            quality=Quality(row["quality"]) if row["quality"] is not None else None,
            graphic=row["graphic"],
            implicit=row["implicit"],
            low_quality=LowQualityType(row["low_quality"])
            if row["low_quality"] is not None
            else None,
            set_id=row["set_id"],
            unique_id=row["unique_id"],
            nameword1=row["nameword1"],
            nameword2=row["nameword2"],
            runeword_id=row["runeword_id"],
            personal_name=row["personal_name"],
            defense=row["defense"],
            durability=row["durability"],
            max_durability=row["max_durability"],
            quantity=row["quantity"],
        )
        if item.quality in [Quality.MAGIC, Quality.RARE, Quality.CRAFTED]:
            rows = db.execute(
                "SELECT prefix, affix_id FROM item_affix WHERE item_id = ?", (id,)
            )
            item.prefixes = []
            item.suffixes = []
            for row in rows:
                if row["prefix"]:
                    item.prefixes.append(row["affix_id"])
                else:
                    item.suffixes.append(row["affix_id"])
        rows = db.execute(
            "SELECT id FROM item WHERE socketed_into = ?", (id,)
        ).fetchall()
        if len(rows) > 0:
            item.sockets = []
            for row in rows:
                socket = Item.load_from_db(row["id"])
                socket.pos_x = len(item.sockets)
                item.sockets.append(socket)
        rows = db.execute(
            "SELECT stat, value1, value2, value3, parameter FROM item_stat WHERE item_id = ?",
            (id,),
        ).fetchall()
        if len(rows) > 0:
            item.stats = []
            for row in rows:
                values = []
                for i in range(1, 4):
                    if row[f"value{i}"] is not None:
                        values.append(row[f"value{i}"])
                stat = Stat(id=row["stat"], values=values, parameter=row["parameter"])
                stat_data = lookup_stat(stat.id)
                stat.text = stat_data["text"]
                item.stats.append(stat)
        return item


def lookup_basetype(code: str) -> dict:
    global _basetype_map
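
For reference, write_to_db() stores each socket with socketed_into pointing at the parent row and defers the commit, and load_from_db() reassembles children by querying that column. A small sketch of inspecting those rows directly (the helper socket_codes is hypothetical, not part of the diff):

from d2warehouse.db import get_db

def socket_codes(parent_id: int) -> list[str]:
    # Hypothetical helper: return the base codes of the items socketed
    # into the item row `parent_id`, as written by Item.write_to_db().
    db = get_db()
    rows = db.execute(
        "SELECT code FROM item WHERE socketed_into = ?", (parent_id,)
    ).fetchall()
    return [row["code"] for row in rows]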

@@ -27,7 +27,7 @@ from d2warehouse.item import (
     lookup_stat,
 )
 import d2warehouse.huffman as huffman
-from d2warehouse.fileformat import STASH_TAB_MAGIC, ITEM_DATA_MAGIC
+from d2warehouse.fileformat import STASH_TAB_MAGIC, STASH_TAB_VERSION, ITEM_DATA_MAGIC
 
 
 class ParseError(RuntimeError):
@@ -82,7 +82,7 @@ def parse_stash_tab(data: bytes) -> tuple[bytes, StashTab]:
     if unknown != 1:
         raise ParseError("Unknown stash tab field is not 1")
 
-    if version != 99:
+    if version != STASH_TAB_VERSION:
         raise ParseError(f"Unsupported stash tab version ({version} instead of 99)")
 
     tab = StashTab()
@@ -133,6 +133,7 @@ def parse_item(data: bytes) -> tuple[bytes, Item]:
     simple_byte_sz = int((sockets_end + 7) / 8)
     item = Item(
         data[:simple_byte_sz],
+        STASH_TAB_VERSION,
         is_identified,
         is_socketed,
         is_beginner,

d2warehouse/schema.sql (new file, 85 lines)

@@ -0,0 +1,85 @@
DROP TABLE IF EXISTS item_stat;
DROP TABLE IF EXISTS item_affix;
DROP TABLE IF EXISTS item_extra;
DROP TABLE IF EXISTS item;

CREATE TABLE item (
    id INTEGER PRIMARY KEY,
    created TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
    deleted TIMESTAMP DEFAULT NULL,
    -- Nuked: the item has been removed from storage and the user indicated they no
    -- longer want to count it as potentially in their possession.
    nuked BOOLEAN DEFAULT NULL,
    -- Names: simple items only have a name equal to their base name; most non-simple
    -- items have two names: the base name and item_extra.item_name.
    itembase_name TEXT NOT NULL,
    socketed_into INTEGER DEFAULT NULL,
    -- The following fields match the fields of the Item object in item.py.
    raw_data BLOB NOT NULL,
    raw_version INTEGER NOT NULL,
    is_identified BOOLEAN NOT NULL,
    is_socketed BOOLEAN NOT NULL,
    is_beginner BOOLEAN NOT NULL,
    is_simple BOOLEAN NOT NULL,
    is_ethereal BOOLEAN NOT NULL,
    is_personalized BOOLEAN NOT NULL,
    is_runeword BOOLEAN NOT NULL,
    code TEXT NOT NULL,
    FOREIGN KEY (socketed_into) REFERENCES item (id)
);

-- Index for "... WHERE deleted IS NULL" queries.
CREATE INDEX item_deletion_partial ON item (deleted) WHERE deleted IS NULL;

CREATE TABLE item_extra (
    item_id INTEGER PRIMARY KEY,
    item_name TEXT DEFAULT NULL,
    set_name TEXT DEFAULT NULL,
    -- The following fields match the fields of the Item object in item.py.
    uid INTEGER DEFAULT NULL,
    lvl INTEGER DEFAULT NULL,
    quality INTEGER DEFAULT NULL CHECK(1 <= quality AND quality <= 8),
    graphic INTEGER DEFAULT NULL,
    implicit INTEGER DEFAULT NULL,
    low_quality INTEGER DEFAULT NULL CHECK(0 <= low_quality AND low_quality <= 3),
    set_id INTEGER DEFAULT NULL,
    unique_id INTEGER DEFAULT NULL,
    nameword1 INTEGER DEFAULT NULL,
    nameword2 INTEGER DEFAULT NULL,
    runeword_id INTEGER DEFAULT NULL,
    personal_name TEXT DEFAULT NULL,
    defense INTEGER DEFAULT NULL,
    durability INTEGER DEFAULT NULL,
    max_durability INTEGER DEFAULT NULL,
    -- sockets: list[Optional["Item"]] | None = None => see item.socketed_into
    quantity INTEGER DEFAULT NULL,
    -- stats: list[Stat] | None = None => see table 'item_stat'
    FOREIGN KEY (item_id) REFERENCES item (id)
);

CREATE TABLE item_stat (
    id INTEGER PRIMARY KEY,
    item_id INTEGER NOT NULL,
    stat INTEGER NOT NULL,
    value1 INTEGER DEFAULT NULL,
    value2 INTEGER DEFAULT NULL,
    value3 INTEGER DEFAULT NULL,
    parameter INTEGER DEFAULT NULL,
    FOREIGN KEY (item_id) REFERENCES item (id)
);
CREATE INDEX item_stat_stat ON item_stat (stat);

CREATE TABLE item_affix (
    id INTEGER PRIMARY KEY,
    item_id INTEGER NOT NULL,
    prefix BOOLEAN NOT NULL,
    affix_id INTEGER NOT NULL,
    FOREIGN KEY (item_id) REFERENCES item (id)
);
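
A sketch of how the item / item_extra split is meant to be queried (hypothetical helper, written against get_db() from db.py; not part of the diff):

from d2warehouse.db import get_db

def list_items() -> list[dict]:
    # Hypothetical helper: stored items that have not been soft-deleted.
    db = get_db()
    rows = db.execute(
        """SELECT item.id, item.itembase_name, item_extra.item_name, item_extra.lvl
           FROM item
           LEFT JOIN item_extra ON item_extra.item_id = item.id
           WHERE item.deleted IS NULL
           ORDER BY item.created"""
    ).fetchall()
    return [dict(row) for row in rows]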

@@ -0,0 +1,31 @@
import os
import tempfile
import unittest

from d2warehouse.db import close_db, create_db, set_db_path
from d2warehouse.item import Item
from d2warehouse.parser import parse_item


class DbTest(unittest.TestCase):
    @classmethod
    def setUpClass(cls) -> None:
        cls._fd, cls._path = tempfile.mkstemp()
        set_db_path(cls._path)
        create_db()

    @classmethod
    def tearDownClass(cls):
        close_db()
        os.close(cls._fd)
        os.unlink(cls._path)

    def test_runeword_lore(self):
        data = bytes.fromhex(
            "10088004050054c637f1073af4558697412981070881506049e87f005516fb134582ff1000a0003500e07cbb001000a0003504e07c9800"
        )
        _, item = parse_item(data)

        db_id = item.write_to_db()
        loaded_item = Item.load_from_db(db_id)
        self.assertEqual(item, loaded_item)

@@ -44,7 +44,7 @@ d2test = "d2warehouse.test:main"
 version = {attr = "d2warehouse.__version__"}
 
 [tool.setuptools.package-data]
-d2warehouse = ["data/*.json"]
+d2warehouse = ["data/*.json", "schema.sql"]
 
 [tool.pytest.ini_options]
 addopts = "--cov --cov-report html --cov-report term"