diff --git a/d2warehouse/item.py b/d2warehouse/item.py index e2511f6..14240d2 100644 --- a/d2warehouse/item.py +++ b/d2warehouse/item.py @@ -17,7 +17,6 @@ import json import os import re -from typing import Optional from bitarray import bitarray from bitarray.util import int2ba from dataclasses import dataclass @@ -119,7 +118,8 @@ class Item: defense: int | None = None durability: int | None = None max_durability: int | None = None - sockets: list[Optional["Item"]] | None = None + sockets: int | None = None + socketed_items: list["Item"] | None = None quantity: int | None = None stats: list[Stat] | None = None @@ -181,10 +181,9 @@ class Item: def raw(self): parts = [self.raw_data] - if self.sockets: - for item in self.sockets: - if item: - parts.append(item.raw_data) + if self.socketed_items: + for item in self.socketed_items: + parts.append(item.raw_data) return b"".join(parts) def print(self, indent=5, with_raw=False): @@ -233,12 +232,9 @@ class Item: f"Durability: {self.durability} out of {self.max_durability}", ) if self.is_socketed: - print(" " * indent, f"{len(self.sockets)} sockets:") - for socket in self.sockets: - if socket: - socket.print(indent + 4) - else: - print(" " * (indent + 4), "Empty") + print(" " * indent, f"{len(self.socketed_items)}/{self.sockets} sockets:") + for socket in self.socketed_items: + socket.print(indent + 4) if self.quantity: print(" " * indent, f"Quantity: {self.quantity}") if self.stats: @@ -296,9 +292,9 @@ class Item: reqs["lvl"] = max(reqs["lvl"], affix["req_lvl"]) if affix["req_class"]: reqs["class"] = affix["req_class"] - if self.sockets: - for socket in self.sockets: - socket_reqs = socket.requirements() + if self.socketed_items: + for socket_item in self.socketed_items: + socket_reqs = socket_item.requirements() reqs["lvl"] = max(reqs["lvl"], socket_reqs["lvl"]) return reqs @@ -348,9 +344,10 @@ class Item: """INSERT INTO item_extra (item_id, item_name, set_name, uid, lvl, quality, graphic, implicit, low_quality, set_id, unique_id, nameword1, nameword2, runeword_id, personal_name, - defense, durability, max_durability, quantity, req_str, req_dex, - req_class) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""", + defense, durability, max_durability, sockets, quantity, req_str, + req_dex, req_class) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, + ?, ?, ?)""", ( item_id, name, @@ -370,6 +367,7 @@ class Item: self.defense, self.durability, self.max_durability, + self.sockets, self.quantity, req["str"], req["dex"], @@ -402,9 +400,9 @@ class Item: ), ) - if self.sockets: - for socket in self.sockets: - socket.write_to_db(socketed_into=item_id, commit=False) + if self.socketed_items: + for socket_item in self.socketed_items: + socket_item.write_to_db(socketed_into=item_id, commit=False, db=db) if commit: db.commit() @@ -419,7 +417,7 @@ class Item: is_beginner, is_simple, is_ethereal, is_personalized, is_runeword, code, uid, lvl, quality, graphic, implicit, low_quality, set_id, unique_id, nameword1, nameword2, runeword_id, personal_name, defense, durability, - max_durability, quantity + max_durability, sockets, quantity FROM item LEFT JOIN item_extra ON id = item_id WHERE id = ?""", (id,), ).fetchone() @@ -456,6 +454,7 @@ class Item: defense=row["defense"], durability=row["durability"], max_durability=row["max_durability"], + sockets=row["sockets"], quantity=row["quantity"], ) @@ -471,15 +470,16 @@ class Item: else: item.suffixes.append(row["affix_id"]) - rows = db.execute( - "SELECT id FROM item WHERE socketed_into = ?", (id,) - ).fetchall() - if len(rows) > 0: - item.sockets = [] - for row in rows: - socket = Item.load_from_db(row["id"]) - socket.pos_x = len(item.sockets) - item.sockets.append(socket) + if item.is_socketed: + item.socketed_items = [] + rows = db.execute( + "SELECT id FROM item WHERE socketed_into = ?", (id,) + ).fetchall() + if len(rows) > 0: + for row in rows: + socket_item = Item.load_from_db(row["id"], db=db) + socket_item.pos_x = len(item.socketed_items) + item.socketed_items.append(socket_item) rows = db.execute( "SELECT stat, value1, value2, value3, parameter FROM item_stat WHERE item_id = ?", diff --git a/d2warehouse/parser.py b/d2warehouse/parser.py index d499611..d96ad83 100644 --- a/d2warehouse/parser.py +++ b/d2warehouse/parser.py @@ -179,7 +179,8 @@ def parse_item(data: bytes) -> tuple[bytes, Item]: itembase_end += personalized_end if item.is_socketed: - sockets_count = ba2int(bits[itembase_end : itembase_end + 4]) + item.sockets = ba2int(bits[itembase_end : itembase_end + 4]) + item.socketed_items = [] sockets_end = itembase_end + 4 else: sockets_end = itembase_end @@ -194,10 +195,9 @@ def parse_item(data: bytes) -> tuple[bytes, Item]: # Parse out sockets if any exist on the item if item.is_socketed: - item.sockets = [None] * sockets_count for i in range(0, filled_sockets): - remaining_data, socket = parse_item(remaining_data) - item.sockets[i] = socket + remaining_data, socket_item = parse_item(remaining_data) + item.socketed_items.append(socket_item) return remaining_data, item diff --git a/d2warehouse/schema.sql b/d2warehouse/schema.sql index 5f3400b..ba7f9bf 100644 --- a/d2warehouse/schema.sql +++ b/d2warehouse/schema.sql @@ -59,7 +59,7 @@ CREATE TABLE item_extra ( defense INTEGER DEFAULT NULL, durability INTEGER DEFAULT NULL, max_durability INTEGER DEFAULT NULL, - -- sockets: list[Optional["Item"]] | None = None => see item.socketed_into + sockets INTEGER DEFAULT NULL, -- number of sockets; see item.socketed_into quantity INTEGER DEFAULT NULL, -- stats: list[Stat] | None = None => see table 'item_stat' diff --git a/d2warehouse/tests/test_db.py b/d2warehouse/tests/test_db.py index d27f24c..a020d8d 100644 --- a/d2warehouse/tests/test_db.py +++ b/d2warehouse/tests/test_db.py @@ -26,17 +26,30 @@ class DbTest(unittest.TestCase): ) _, item = parse_item(data) - db_id = item.write_to_db() - loaded_item = Item.load_from_db(db_id) + db = get_db() + db_id = item.write_to_db(db=db) + loaded_item = Item.load_from_db(db_id, db=db) self.assertEqual(item, loaded_item) # Check that requirement was written properly - db = get_db() reqs = db.execute( - "SELECT req_lvl, req_str, req_dex, req_class FROM item JOIN item_extra ON id = item_id WHERE id = 1" + "SELECT req_lvl, req_str, req_dex, req_class FROM item JOIN item_extra ON id = item_id WHERE id = ?", + (db_id,), ).fetchone() expected_reqs = item.requirements() self.assertEqual(reqs["req_lvl"], expected_reqs["lvl"]) self.assertEqual(reqs["req_str"], expected_reqs["str"]) self.assertEqual(reqs["req_dex"], expected_reqs["dex"]) self.assertEqual(reqs["req_class"], expected_reqs["class"]) + + def test_empty_sockets(self): + # superior armor with empty sockets + data = bytes.fromhex("10088000050014df175043b1b90cc38d80e3834070b004f41f") + _, item = parse_item(data) + + db = get_db() + db_id = item.write_to_db(db=db) + loaded_item = Item.load_from_db(db_id, db=db) + self.assertEqual(loaded_item.sockets, 2) + self.assertEqual(len(loaded_item.socketed_items), 0) + self.assertEqual(item, loaded_item) diff --git a/d2warehouse/tests/test_parse_item.py b/d2warehouse/tests/test_parse_item.py index 0c13d65..ecb2c5b 100644 --- a/d2warehouse/tests/test_parse_item.py +++ b/d2warehouse/tests/test_parse_item.py @@ -68,7 +68,8 @@ class ParseItemTest(unittest.TestCase): self.assertEqual(data, b"") self.assertEqual(item.quality, Quality.HIGH) self.assertEqual(len(item.stats), 2) - self.assertEqual(len(item.sockets), 2) + self.assertEqual(item.sockets, 2) + self.assertEqual(len(item.socketed_items), 0) def test_ed_max(self): # test bugfix for https://gitlab.com/omicron-oss/d2warehouse/-/issues/1 @@ -89,11 +90,9 @@ class ParseItemTest(unittest.TestCase): data, item = parse_item(data) self.assertEqual(data, b"") self.assertTrue(item.is_runeword) - self.assertEqual(len(item.sockets), 2) - item.sockets[0].print() - item.sockets[1].print() - self.assertEqual(item.sockets[0].code, "r09") - self.assertEqual(item.sockets[1].code, "r12") + self.assertEqual(item.sockets, 2) + self.assertEqual(item.socketed_items[0].code, "r09") + self.assertEqual(item.socketed_items[1].code, "r12") rw = lookup_runeword(item.runeword_id) self.assertEqual(rw["name"], "Lore") self.assertEqual(str(item.stats[4]), "+1 to All Skills") # runeword stat