Split sockets into sockets and socketed_items

This commit is contained in:
2023-10-28 18:43:35 +02:00
committed by omicron
parent 423e4368d7
commit 21df28d7bb
5 changed files with 58 additions and 46 deletions

View File

@@ -17,7 +17,6 @@
import json import json
import os import os
import re import re
from typing import Optional
from bitarray import bitarray from bitarray import bitarray
from bitarray.util import int2ba from bitarray.util import int2ba
from dataclasses import dataclass from dataclasses import dataclass
@@ -119,7 +118,8 @@ class Item:
defense: int | None = None defense: int | None = None
durability: int | None = None durability: int | None = None
max_durability: int | None = None max_durability: int | None = None
sockets: list[Optional["Item"]] | None = None sockets: int | None = None
socketed_items: list["Item"] | None = None
quantity: int | None = None quantity: int | None = None
stats: list[Stat] | None = None stats: list[Stat] | None = None
@@ -181,10 +181,9 @@ class Item:
def raw(self): def raw(self):
parts = [self.raw_data] parts = [self.raw_data]
if self.sockets: if self.socketed_items:
for item in self.sockets: for item in self.socketed_items:
if item: parts.append(item.raw_data)
parts.append(item.raw_data)
return b"".join(parts) return b"".join(parts)
def print(self, indent=5, with_raw=False): def print(self, indent=5, with_raw=False):
@@ -233,12 +232,9 @@ class Item:
f"Durability: {self.durability} out of {self.max_durability}", f"Durability: {self.durability} out of {self.max_durability}",
) )
if self.is_socketed: if self.is_socketed:
print(" " * indent, f"{len(self.sockets)} sockets:") print(" " * indent, f"{len(self.socketed_items)}/{self.sockets} sockets:")
for socket in self.sockets: for socket in self.socketed_items:
if socket: socket.print(indent + 4)
socket.print(indent + 4)
else:
print(" " * (indent + 4), "Empty")
if self.quantity: if self.quantity:
print(" " * indent, f"Quantity: {self.quantity}") print(" " * indent, f"Quantity: {self.quantity}")
if self.stats: if self.stats:
@@ -296,9 +292,9 @@ class Item:
reqs["lvl"] = max(reqs["lvl"], affix["req_lvl"]) reqs["lvl"] = max(reqs["lvl"], affix["req_lvl"])
if affix["req_class"]: if affix["req_class"]:
reqs["class"] = affix["req_class"] reqs["class"] = affix["req_class"]
if self.sockets: if self.socketed_items:
for socket in self.sockets: for socket_item in self.socketed_items:
socket_reqs = socket.requirements() socket_reqs = socket_item.requirements()
reqs["lvl"] = max(reqs["lvl"], socket_reqs["lvl"]) reqs["lvl"] = max(reqs["lvl"], socket_reqs["lvl"])
return reqs return reqs
@@ -348,9 +344,10 @@ class Item:
"""INSERT INTO item_extra (item_id, item_name, """INSERT INTO item_extra (item_id, item_name,
set_name, uid, lvl, quality, graphic, implicit, low_quality, set_id, set_name, uid, lvl, quality, graphic, implicit, low_quality, set_id,
unique_id, nameword1, nameword2, runeword_id, personal_name, unique_id, nameword1, nameword2, runeword_id, personal_name,
defense, durability, max_durability, quantity, req_str, req_dex, defense, durability, max_durability, sockets, quantity, req_str,
req_class) req_dex, req_class)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""", VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?,
?, ?, ?)""",
( (
item_id, item_id,
name, name,
@@ -370,6 +367,7 @@ class Item:
self.defense, self.defense,
self.durability, self.durability,
self.max_durability, self.max_durability,
self.sockets,
self.quantity, self.quantity,
req["str"], req["str"],
req["dex"], req["dex"],
@@ -402,9 +400,9 @@ class Item:
), ),
) )
if self.sockets: if self.socketed_items:
for socket in self.sockets: for socket_item in self.socketed_items:
socket.write_to_db(socketed_into=item_id, commit=False) socket_item.write_to_db(socketed_into=item_id, commit=False, db=db)
if commit: if commit:
db.commit() db.commit()
@@ -419,7 +417,7 @@ class Item:
is_beginner, is_simple, is_ethereal, is_personalized, is_runeword, code, is_beginner, is_simple, is_ethereal, is_personalized, is_runeword, code,
uid, lvl, quality, graphic, implicit, low_quality, set_id, unique_id, uid, lvl, quality, graphic, implicit, low_quality, set_id, unique_id,
nameword1, nameword2, runeword_id, personal_name, defense, durability, nameword1, nameword2, runeword_id, personal_name, defense, durability,
max_durability, quantity max_durability, sockets, quantity
FROM item LEFT JOIN item_extra ON id = item_id WHERE id = ?""", FROM item LEFT JOIN item_extra ON id = item_id WHERE id = ?""",
(id,), (id,),
).fetchone() ).fetchone()
@@ -456,6 +454,7 @@ class Item:
defense=row["defense"], defense=row["defense"],
durability=row["durability"], durability=row["durability"],
max_durability=row["max_durability"], max_durability=row["max_durability"],
sockets=row["sockets"],
quantity=row["quantity"], quantity=row["quantity"],
) )
@@ -471,15 +470,16 @@ class Item:
else: else:
item.suffixes.append(row["affix_id"]) item.suffixes.append(row["affix_id"])
rows = db.execute( if item.is_socketed:
"SELECT id FROM item WHERE socketed_into = ?", (id,) item.socketed_items = []
).fetchall() rows = db.execute(
if len(rows) > 0: "SELECT id FROM item WHERE socketed_into = ?", (id,)
item.sockets = [] ).fetchall()
for row in rows: if len(rows) > 0:
socket = Item.load_from_db(row["id"]) for row in rows:
socket.pos_x = len(item.sockets) socket_item = Item.load_from_db(row["id"], db=db)
item.sockets.append(socket) socket_item.pos_x = len(item.socketed_items)
item.socketed_items.append(socket_item)
rows = db.execute( rows = db.execute(
"SELECT stat, value1, value2, value3, parameter FROM item_stat WHERE item_id = ?", "SELECT stat, value1, value2, value3, parameter FROM item_stat WHERE item_id = ?",

View File

@@ -179,7 +179,8 @@ def parse_item(data: bytes) -> tuple[bytes, Item]:
itembase_end += personalized_end itembase_end += personalized_end
if item.is_socketed: if item.is_socketed:
sockets_count = ba2int(bits[itembase_end : itembase_end + 4]) item.sockets = ba2int(bits[itembase_end : itembase_end + 4])
item.socketed_items = []
sockets_end = itembase_end + 4 sockets_end = itembase_end + 4
else: else:
sockets_end = itembase_end sockets_end = itembase_end
@@ -194,10 +195,9 @@ def parse_item(data: bytes) -> tuple[bytes, Item]:
# Parse out sockets if any exist on the item # Parse out sockets if any exist on the item
if item.is_socketed: if item.is_socketed:
item.sockets = [None] * sockets_count
for i in range(0, filled_sockets): for i in range(0, filled_sockets):
remaining_data, socket = parse_item(remaining_data) remaining_data, socket_item = parse_item(remaining_data)
item.sockets[i] = socket item.socketed_items.append(socket_item)
return remaining_data, item return remaining_data, item

View File

@@ -59,7 +59,7 @@ CREATE TABLE item_extra (
defense INTEGER DEFAULT NULL, defense INTEGER DEFAULT NULL,
durability INTEGER DEFAULT NULL, durability INTEGER DEFAULT NULL,
max_durability INTEGER DEFAULT NULL, max_durability INTEGER DEFAULT NULL,
-- sockets: list[Optional["Item"]] | None = None => see item.socketed_into sockets INTEGER DEFAULT NULL, -- number of sockets; see item.socketed_into
quantity INTEGER DEFAULT NULL, quantity INTEGER DEFAULT NULL,
-- stats: list[Stat] | None = None => see table 'item_stat' -- stats: list[Stat] | None = None => see table 'item_stat'

View File

@@ -26,17 +26,30 @@ class DbTest(unittest.TestCase):
) )
_, item = parse_item(data) _, item = parse_item(data)
db_id = item.write_to_db() db = get_db()
loaded_item = Item.load_from_db(db_id) db_id = item.write_to_db(db=db)
loaded_item = Item.load_from_db(db_id, db=db)
self.assertEqual(item, loaded_item) self.assertEqual(item, loaded_item)
# Check that requirement was written properly # Check that requirement was written properly
db = get_db()
reqs = db.execute( reqs = db.execute(
"SELECT req_lvl, req_str, req_dex, req_class FROM item JOIN item_extra ON id = item_id WHERE id = 1" "SELECT req_lvl, req_str, req_dex, req_class FROM item JOIN item_extra ON id = item_id WHERE id = ?",
(db_id,),
).fetchone() ).fetchone()
expected_reqs = item.requirements() expected_reqs = item.requirements()
self.assertEqual(reqs["req_lvl"], expected_reqs["lvl"]) self.assertEqual(reqs["req_lvl"], expected_reqs["lvl"])
self.assertEqual(reqs["req_str"], expected_reqs["str"]) self.assertEqual(reqs["req_str"], expected_reqs["str"])
self.assertEqual(reqs["req_dex"], expected_reqs["dex"]) self.assertEqual(reqs["req_dex"], expected_reqs["dex"])
self.assertEqual(reqs["req_class"], expected_reqs["class"]) self.assertEqual(reqs["req_class"], expected_reqs["class"])
def test_empty_sockets(self):
# superior armor with empty sockets
data = bytes.fromhex("10088000050014df175043b1b90cc38d80e3834070b004f41f")
_, item = parse_item(data)
db = get_db()
db_id = item.write_to_db(db=db)
loaded_item = Item.load_from_db(db_id, db=db)
self.assertEqual(loaded_item.sockets, 2)
self.assertEqual(len(loaded_item.socketed_items), 0)
self.assertEqual(item, loaded_item)

View File

@@ -68,7 +68,8 @@ class ParseItemTest(unittest.TestCase):
self.assertEqual(data, b"") self.assertEqual(data, b"")
self.assertEqual(item.quality, Quality.HIGH) self.assertEqual(item.quality, Quality.HIGH)
self.assertEqual(len(item.stats), 2) self.assertEqual(len(item.stats), 2)
self.assertEqual(len(item.sockets), 2) self.assertEqual(item.sockets, 2)
self.assertEqual(len(item.socketed_items), 0)
def test_ed_max(self): def test_ed_max(self):
# test bugfix for https://gitlab.com/omicron-oss/d2warehouse/-/issues/1 # test bugfix for https://gitlab.com/omicron-oss/d2warehouse/-/issues/1
@@ -89,11 +90,9 @@ class ParseItemTest(unittest.TestCase):
data, item = parse_item(data) data, item = parse_item(data)
self.assertEqual(data, b"") self.assertEqual(data, b"")
self.assertTrue(item.is_runeword) self.assertTrue(item.is_runeword)
self.assertEqual(len(item.sockets), 2) self.assertEqual(item.sockets, 2)
item.sockets[0].print() self.assertEqual(item.socketed_items[0].code, "r09")
item.sockets[1].print() self.assertEqual(item.socketed_items[1].code, "r12")
self.assertEqual(item.sockets[0].code, "r09")
self.assertEqual(item.sockets[1].code, "r12")
rw = lookup_runeword(item.runeword_id) rw = lookup_runeword(item.runeword_id)
self.assertEqual(rw["name"], "Lore") self.assertEqual(rw["name"], "Lore")
self.assertEqual(str(item.stats[4]), "+1 to All Skills") # runeword stat self.assertEqual(str(item.stats[4]), "+1 to All Skills") # runeword stat