5 Commits

14 changed files with 348 additions and 46 deletions

View File

@@ -0,0 +1 @@
from d2warehouse.app.main import app

22
d2warehouse/app/db.py Normal file
View File

@@ -0,0 +1,22 @@
import sqlite3
from flask import g
import d2warehouse.db as base_db
def get_db():
if "db" not in g:
print("\n==========\nDB PATH", base_db._path, "\n============\n")
g.db = sqlite3.connect(
base_db._path,
detect_types=sqlite3.PARSE_DECLTYPES,
)
g.db.row_factory = sqlite3.Row
return g.db
def close_db(e=None):
db = g.pop("db", None)
if db is not None:
db.close()

164
d2warehouse/app/main.py Normal file
View File

@@ -0,0 +1,164 @@
import hashlib
from flask import Flask, redirect, abort, render_template, request
from pathlib import Path
from d2warehouse.item import Item
from d2warehouse.parser import parse_stash
import d2warehouse.db as base_db
from d2warehouse.app.db import get_db, close_db
import os
import re
from d2warehouse.stash import StashFullError
STASH_FILES = {
"softcore": "SharedStashSoftCoreV2.d2i",
"hardcore": "SharedStashHardCoreV2.d2i",
}
DB_FILES = {
"softcore": "d2warehouse.softcore.sqlite3",
"hardcore": "d2warehouse.hardcore.sqlite3",
}
def save_path() -> Path:
if "D2SAVE_PATH" in os.environ:
path = Path(os.environ["D2SAVE_PATH"])
else:
path = Path.home() / "Saved Games/Diablo II Resurrected"
if not path.exists():
raise RuntimeError(f"Save path `{path}` does not exist")
return path
def get_stash_db(stash):
base_db.set_db_path(str(save_path() / DB_FILES[stash]))
return get_db()
base_db.set_db_path(str(save_path() / DB_FILES["softcore"]))
base_db.init_db()
base_db.close_db()
base_db.set_db_path(str(save_path() / DB_FILES["hardcore"]))
base_db.init_db()
base_db.close_db()
app = Flask(__name__)
app.teardown_appcontext(close_db)
@app.route("/")
def home():
return redirect("/stash/softcore", code=302)
@app.route("/stash/<stash_name>")
def list_stash(stash_name: str):
if stash_name not in STASH_FILES:
abort(404)
path = save_path() / STASH_FILES[stash_name]
stash_data = path.read_bytes()
stash_hash = hashlib.sha256(stash_data).hexdigest()
stash = parse_stash(stash_data)
return render_template(
"list_stash.html", stash_name=stash_name, stash=stash, stash_hash=stash_hash
)
@app.route("/stash/<stash_name>/store", methods=["POST"])
def stash_store_items(stash_name: str):
if stash_name not in STASH_FILES or stash_name not in DB_FILES:
abort(404)
stash_path = save_path() / STASH_FILES[stash_name]
tmp_path = save_path() / f"{STASH_FILES[stash_name]}.temp"
if tmp_path.exists():
# TODO: Handle this condition
return "temp file exists (BAD)"
return 500
stash_data = stash_path.read_bytes()
stash_hash = hashlib.sha256(stash_data).hexdigest()
if request.form.get("stash_hash") != stash_hash:
return "wrong stash hash", 400
stash = parse_stash(stash_data)
items = []
locs = [y for x in request.form.keys() if (y := re.match(r"item_(\d+)_(\d+)", x))]
for item_location in locs:
tab_idx, item_idx = int(item_location.group(1)), int(item_location.group(2))
if tab_idx > len(stash.tabs) or item_idx > len(stash.tabs[tab_idx].items):
# TODO: Handle this condition
return "invalid position (2)"
item = stash.tabs[tab_idx].items[item_idx]
items.append((tab_idx, item))
# TODO: create backups
for tab_idx, item in items:
stash.tabs[tab_idx].remove(item)
tmp_path.write_bytes(stash.raw())
db = get_stash_db(stash_name)
for _, item in items:
item.write_to_db(db=db)
tmp_path.replace(stash_path)
return redirect(f"/stash/{stash_name}", code=303)
@app.route("/storage/<stash_name>")
def list_storage(stash_name: str):
if stash_name not in DB_FILES:
abort(404)
db = get_stash_db(stash_name)
items = {}
rows = db.execute("SELECT id FROM item WHERE deleted IS NULL").fetchall()
for row in rows:
items[row["id"]] = Item.load_from_db(row["id"], db=db)
return render_template(
"list_storage.html", stash_name=stash_name, storage_items=items
)
@app.route("/storage/<stash_name>/take", methods=["POST"])
def storage_take_items(stash_name: str):
if stash_name not in STASH_FILES or stash_name not in DB_FILES:
abort(404)
stash_path = save_path() / STASH_FILES[stash_name]
tmp_path = save_path() / f"{STASH_FILES[stash_name]}.temp"
if tmp_path.exists():
# TODO: Handle this condition
return "temp file exists (BAD)"
return 500
stash_data = stash_path.read_bytes()
stash = parse_stash(stash_data)
# TODO: create backups
# Write items to temporary stash file
db = get_stash_db(stash_name)
ids = [y.group(1) for x in request.form.keys() if (y := re.match(r"item_(\d+)", x))]
for id in ids:
item = Item.load_from_db(id, db=db)
try:
stash.add(item)
except StashFullError:
return "the shared stash does not fit those items", 500
tmp_path.write_bytes(stash.raw())
# Remove items from db
for id in ids:
db.execute("UPDATE item SET deleted = CURRENT_TIMESTAMP WHERE id = ?", (id,))
db.commit()
# Finalize by replacing real stash file
tmp_path.replace(stash_path)
return redirect(f"/storage/{stash_name}", code=303)

View File

@@ -0,0 +1,30 @@
body {
background-color: #000;
font-size: large;
font-family: sans-serif;
color: rgb(240, 240, 240);
}
.item .name {
font-weight: bold;
}
.item {
background-color: #444;
}
.color-rare {
color: rgb(255, 255, 100);
}
.color-unique {
color: rgb(199, 179, 119);
}
.color-set {
color: rgb(0, 252, 0);
}
.color-runeword {
color: rgb(199, 179, 119);
}

View File

@@ -0,0 +1,14 @@
<div class="item">
<input type="checkbox" name="item_{{tabloop.index0}}_{{itemloop.index0}}" value="remove" /> ({{tabloop.index0}}, {{itemloop.index0}})
<ul>
<li class="name color-{{item.color}}">{{item.name}}</li>
{% if item.quality and item.quality >= 5 %}
<li class="name color-{{item.color}}">{{item.basename}}</li>
{% endif %}
{% if item.stats %}
{% for stat in item.stats %}
<li>{{stat}}</li>
{% endfor %}
{% endif %}
</ul>
</div>

View File

@@ -0,0 +1,24 @@
<!doctype html>
<html lang="en">
<head>
<meta charset="utf-8">
<title>Shared Stash</title>
<link rel="stylesheet" href="/static/style.css" />
<head>
<body>
<form action="/stash/{{stash_name}}/store" method="POST">
{% for tab in stash.tabs %}
<div>
{% set tabloop = loop %}
<h2>Tab {{tabloop.index}}</h2>
{% for item in tab.items %}
{% set itemloop = loop %}
{% include "item.html" %}
{% endfor %}
</div>
{% endfor %}
<input type="submit" value="Store items">
<input type="hidden" name="stash_hash" value="{{stash_hash}}" />
</form>
</body>
</html>

View File

@@ -0,0 +1,23 @@
<!doctype html>
<html lang="en">
<head>
<meta charset="utf-8">
<title>Storage</title>
<link rel="stylesheet" href="/static/style.css" />
<head>
<body>
<form action="/storage/{{stash_name}}/take" method="POST">
<div>
<!-- TODO: Include item.html -->
{% for db_id, item in storage_items.items() %}
<div>
<input type="checkbox" name="item_{{db_id}}" value="take" />
{{item.name}}
({{db_id}})
</div>
{% endfor %}
</div>
<input type="submit" value="Take items">
</form>
</body>
</html>

View File

@@ -17,7 +17,6 @@
import json import json
import os import os
import re import re
from typing import Optional
from bitarray import bitarray from bitarray import bitarray
from bitarray.util import int2ba from bitarray.util import int2ba
from dataclasses import dataclass from dataclasses import dataclass
@@ -118,7 +117,8 @@ class Item:
defense: int | None = None defense: int | None = None
durability: int | None = None durability: int | None = None
max_durability: int | None = None max_durability: int | None = None
sockets: list[Optional["Item"]] | None = None sockets: int | None = None
socketed_items: list["Item"] | None = None
quantity: int | None = None quantity: int | None = None
stats: list[Stat] | None = None stats: list[Stat] | None = None
@@ -180,9 +180,8 @@ class Item:
def raw(self): def raw(self):
parts = [self.raw_data] parts = [self.raw_data]
if self.sockets: if self.socketed_items:
for item in self.sockets: for item in self.socketed_items:
if item:
parts.append(item.raw_data) parts.append(item.raw_data)
return b"".join(parts) return b"".join(parts)
@@ -232,12 +231,9 @@ class Item:
f"Durability: {self.durability} out of {self.max_durability}", f"Durability: {self.durability} out of {self.max_durability}",
) )
if self.is_socketed: if self.is_socketed:
print(" " * indent, f"{len(self.sockets)} sockets:") print(" " * indent, f"{len(self.socketed_items)}/{self.sockets} sockets:")
for socket in self.sockets: for socket in self.socketed_items:
if socket:
socket.print(indent + 4) socket.print(indent + 4)
else:
print(" " * (indent + 4), "Empty")
if self.quantity: if self.quantity:
print(" " * indent, f"Quantity: {self.quantity}") print(" " * indent, f"Quantity: {self.quantity}")
if self.stats: if self.stats:
@@ -293,9 +289,9 @@ class Item:
reqs["lvl"] = max(reqs["lvl"], affix["req_lvl"]) reqs["lvl"] = max(reqs["lvl"], affix["req_lvl"])
if affix["req_class"]: if affix["req_class"]:
reqs["class"] = affix["req_class"] reqs["class"] = affix["req_class"]
if self.sockets: if self.socketed_items:
for socket in self.sockets: for socket_item in self.socketed_items:
socket_reqs = socket.requirements() socket_reqs = socket_item.requirements()
reqs["lvl"] = max(reqs["lvl"], socket_reqs["lvl"]) reqs["lvl"] = max(reqs["lvl"], socket_reqs["lvl"])
return reqs return reqs
@@ -345,9 +341,10 @@ class Item:
"""INSERT INTO item_extra (item_id, item_name, """INSERT INTO item_extra (item_id, item_name,
set_name, uid, lvl, quality, graphic, implicit, low_quality, set_id, set_name, uid, lvl, quality, graphic, implicit, low_quality, set_id,
unique_id, nameword1, nameword2, runeword_id, personal_name, unique_id, nameword1, nameword2, runeword_id, personal_name,
defense, durability, max_durability, quantity, req_str, req_dex, defense, durability, max_durability, sockets, quantity, req_str,
req_class) req_dex, req_class)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""", VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?,
?, ?, ?)""",
( (
item_id, item_id,
name, name,
@@ -367,6 +364,7 @@ class Item:
self.defense, self.defense,
self.durability, self.durability,
self.max_durability, self.max_durability,
self.sockets,
self.quantity, self.quantity,
req["str"], req["str"],
req["dex"], req["dex"],
@@ -399,9 +397,9 @@ class Item:
), ),
) )
if self.sockets: if self.socketed_items:
for socket in self.sockets: for socket_item in self.socketed_items:
socket.write_to_db(socketed_into=item_id, commit=False) socket_item.write_to_db(socketed_into=item_id, commit=False, db=db)
if commit: if commit:
db.commit() db.commit()
@@ -416,7 +414,7 @@ class Item:
is_beginner, is_simple, is_ethereal, is_personalized, is_runeword, code, is_beginner, is_simple, is_ethereal, is_personalized, is_runeword, code,
uid, lvl, quality, graphic, implicit, low_quality, set_id, unique_id, uid, lvl, quality, graphic, implicit, low_quality, set_id, unique_id,
nameword1, nameword2, runeword_id, personal_name, defense, durability, nameword1, nameword2, runeword_id, personal_name, defense, durability,
max_durability, quantity max_durability, sockets, quantity
FROM item LEFT JOIN item_extra ON id = item_id WHERE id = ?""", FROM item LEFT JOIN item_extra ON id = item_id WHERE id = ?""",
(id,), (id,),
).fetchone() ).fetchone()
@@ -453,6 +451,7 @@ class Item:
defense=row["defense"], defense=row["defense"],
durability=row["durability"], durability=row["durability"],
max_durability=row["max_durability"], max_durability=row["max_durability"],
sockets=row["sockets"],
quantity=row["quantity"], quantity=row["quantity"],
) )
@@ -468,15 +467,16 @@ class Item:
else: else:
item.suffixes.append(row["affix_id"]) item.suffixes.append(row["affix_id"])
if item.is_socketed:
item.socketed_items = []
rows = db.execute( rows = db.execute(
"SELECT id FROM item WHERE socketed_into = ?", (id,) "SELECT id FROM item WHERE socketed_into = ?", (id,)
).fetchall() ).fetchall()
if len(rows) > 0: if len(rows) > 0:
item.sockets = []
for row in rows: for row in rows:
socket = Item.load_from_db(row["id"]) socket_item = Item.load_from_db(row["id"], db=db)
socket.pos_x = len(item.sockets) socket_item.pos_x = len(item.socketed_items)
item.sockets.append(socket) item.socketed_items.append(socket_item)
rows = db.execute( rows = db.execute(
"SELECT stat, value1, value2, value3, parameter FROM item_stat WHERE item_id = ?", "SELECT stat, value1, value2, value3, parameter FROM item_stat WHERE item_id = ?",

View File

@@ -179,7 +179,8 @@ def parse_item(data: bytes) -> tuple[bytes, Item]:
itembase_end += personalized_end itembase_end += personalized_end
if item.is_socketed: if item.is_socketed:
sockets_count = ba2int(bits[itembase_end : itembase_end + 4]) item.sockets = ba2int(bits[itembase_end : itembase_end + 4])
item.socketed_items = []
sockets_end = itembase_end + 4 sockets_end = itembase_end + 4
else: else:
sockets_end = itembase_end sockets_end = itembase_end
@@ -194,10 +195,9 @@ def parse_item(data: bytes) -> tuple[bytes, Item]:
# Parse out sockets if any exist on the item # Parse out sockets if any exist on the item
if item.is_socketed: if item.is_socketed:
item.sockets = [None] * sockets_count
for i in range(0, filled_sockets): for i in range(0, filled_sockets):
remaining_data, socket = parse_item(remaining_data) remaining_data, socket_item = parse_item(remaining_data)
item.sockets[i] = socket item.socketed_items.append(socket_item)
return remaining_data, item return remaining_data, item

View File

@@ -59,7 +59,7 @@ CREATE TABLE item_extra (
defense INTEGER DEFAULT NULL, defense INTEGER DEFAULT NULL,
durability INTEGER DEFAULT NULL, durability INTEGER DEFAULT NULL,
max_durability INTEGER DEFAULT NULL, max_durability INTEGER DEFAULT NULL,
-- sockets: list[Optional["Item"]] | None = None => see item.socketed_into sockets INTEGER DEFAULT NULL, -- number of sockets; see item.socketed_into
quantity INTEGER DEFAULT NULL, quantity INTEGER DEFAULT NULL,
-- stats: list[Stat] | None = None => see table 'item_stat' -- stats: list[Stat] | None = None => see table 'item_stat'

View File

@@ -34,6 +34,17 @@ class Stash:
def raw(self) -> bytes: def raw(self) -> bytes:
return b"".join(tab.raw() for tab in self.tabs) return b"".join(tab.raw() for tab in self.tabs)
def add(self, item: Item) -> None:
for tab in self.tabs:
try:
tab.add(item)
return
except StashFullError:
pass
raise StashFullError(
"Could not locate an open spot in the stash to add the item"
)
class StashTab: class StashTab:
def __init__(self) -> None: def __init__(self) -> None:

View File

@@ -26,17 +26,30 @@ class DbTest(unittest.TestCase):
) )
_, item = parse_item(data) _, item = parse_item(data)
db_id = item.write_to_db() db = get_db()
loaded_item = Item.load_from_db(db_id) db_id = item.write_to_db(db=db)
loaded_item = Item.load_from_db(db_id, db=db)
self.assertEqual(item, loaded_item) self.assertEqual(item, loaded_item)
# Check that requirement was written properly # Check that requirement was written properly
db = get_db()
reqs = db.execute( reqs = db.execute(
"SELECT req_lvl, req_str, req_dex, req_class FROM item JOIN item_extra ON id = item_id WHERE id = 1" "SELECT req_lvl, req_str, req_dex, req_class FROM item JOIN item_extra ON id = item_id WHERE id = ?",
(db_id,),
).fetchone() ).fetchone()
expected_reqs = item.requirements() expected_reqs = item.requirements()
self.assertEqual(reqs["req_lvl"], expected_reqs["lvl"]) self.assertEqual(reqs["req_lvl"], expected_reqs["lvl"])
self.assertEqual(reqs["req_str"], expected_reqs["str"]) self.assertEqual(reqs["req_str"], expected_reqs["str"])
self.assertEqual(reqs["req_dex"], expected_reqs["dex"]) self.assertEqual(reqs["req_dex"], expected_reqs["dex"])
self.assertEqual(reqs["req_class"], expected_reqs["class"]) self.assertEqual(reqs["req_class"], expected_reqs["class"])
def test_empty_sockets(self):
# superior armor with empty sockets
data = bytes.fromhex("10088000050014df175043b1b90cc38d80e3834070b004f41f")
_, item = parse_item(data)
db = get_db()
db_id = item.write_to_db(db=db)
loaded_item = Item.load_from_db(db_id, db=db)
self.assertEqual(loaded_item.sockets, 2)
self.assertEqual(len(loaded_item.socketed_items), 0)
self.assertEqual(item, loaded_item)

View File

@@ -68,7 +68,8 @@ class ParseItemTest(unittest.TestCase):
self.assertEqual(data, b"") self.assertEqual(data, b"")
self.assertEqual(item.quality, Quality.HIGH) self.assertEqual(item.quality, Quality.HIGH)
self.assertEqual(len(item.stats), 2) self.assertEqual(len(item.stats), 2)
self.assertEqual(len(item.sockets), 2) self.assertEqual(item.sockets, 2)
self.assertEqual(len(item.socketed_items), 0)
def test_ed_max(self): def test_ed_max(self):
# test bugfix for https://gitlab.com/omicron-oss/d2warehouse/-/issues/1 # test bugfix for https://gitlab.com/omicron-oss/d2warehouse/-/issues/1
@@ -89,11 +90,9 @@ class ParseItemTest(unittest.TestCase):
data, item = parse_item(data) data, item = parse_item(data)
self.assertEqual(data, b"") self.assertEqual(data, b"")
self.assertTrue(item.is_runeword) self.assertTrue(item.is_runeword)
self.assertEqual(len(item.sockets), 2) self.assertEqual(item.sockets, 2)
item.sockets[0].print() self.assertEqual(item.socketed_items[0].code, "r09")
item.sockets[1].print() self.assertEqual(item.socketed_items[1].code, "r12")
self.assertEqual(item.sockets[0].code, "r09")
self.assertEqual(item.sockets[1].code, "r12")
rw = lookup_runeword(item.runeword_id) rw = lookup_runeword(item.runeword_id)
self.assertEqual(rw["name"], "Lore") self.assertEqual(rw["name"], "Lore")
self.assertEqual(str(item.stats[4]), "+1 to All Skills") # runeword stat self.assertEqual(str(item.stats[4]), "+1 to All Skills") # runeword stat

View File

@@ -18,6 +18,7 @@ requires-python = ">=3.10"
license = {text = "GPLv3 License"} license = {text = "GPLv3 License"}
dependencies = [ dependencies = [
"bitarray", "bitarray",
"flask",
] ]
dynamic = ["version"] dynamic = ["version"]