# SwapStation_WebApp/backend/core/protobuf_decoder.py

import json
from google.protobuf.json_format import MessageToDict
from google.protobuf.message import DecodeError
# Import the specific message types from the generated protobuf module.
# The module path 'proto.vec_payload_chgSt_pb2' must match where the *_pb2.py file is generated.
from proto.vec_payload_chgSt_pb2 import (
mainPayload as PeriodicData,
eventPayload as EventData,
rpcRequest as RpcRequest,
eventType_e,
jobType_e,
languageType_e
)
class ProtobufDecoder:
"""
Handles the decoding of different Protobuf message types with robust error handling.
"""
def decode_periodic(self, payload: bytes) -> dict | None:
"""
Decodes a binary payload into a PeriodicData dictionary.
"""
try:
message = PeriodicData()
message.ParseFromString(payload)
return MessageToDict(message, preserving_proto_field_name=True)
except DecodeError as e:
print(f"Error decoding PeriodicData: {e}")
return None
except Exception as e:
print(f"An unexpected error occurred during periodic decoding: {e}")
return None
def decode_event(self, payload_bytes: bytes) -> dict | None:
"""
Decodes an event payload robustly, ensuring the correct eventType is used.
"""
try:
# 1. Standard parsing to get a base dictionary
msg = EventData()
msg.ParseFromString(payload_bytes)
d = MessageToDict(msg, preserving_proto_field_name=True)
# 2. Manually extract the true enum value from the raw bytes
wire_num = self._extract_field3_varint(payload_bytes)
wire_name = None
if wire_num is not None:
try:
wire_name = eventType_e.Name(wire_num)
except ValueError:
wire_name = f"UNKNOWN_ENUM_VALUE_{wire_num}"
# 3. Always prefer the manually extracted "wire value"
if wire_name:
d["eventType"] = wire_name
# 4. Ensure consistent structure with default values
ed = d.setdefault("eventData", {})
ed.setdefault("nfcData", None)
ed.setdefault("batteryIdentification", "")
ed.setdefault("activityFailureReason", 0)
ed.setdefault("swapAbortReason", "ABORT_UNKNOWN")
ed.setdefault("swapTime", 0)
ed.setdefault("faultCode", 0)
ed.setdefault("doorStatus", 0)
ed.setdefault("slotId", 0)
# 5. Reorder for clean logs and return as a dictionary
return {
"ts": d.get("ts"),
"deviceId": d.get("deviceId"),
"eventType": d.get("eventType"),
"sessionId": d.get("sessionId"),
"eventData": d.get("eventData"),
}
except Exception as e:
print(f"An unexpected error occurred during event decoding: {e}")
return None
def decode_rpc_request(self, payload_bytes: bytes) -> dict | None:
"""
Decodes an RPC request payload robustly, ensuring the correct jobType is used.
"""
try:
# 1. Standard parsing
msg = RpcRequest()
msg.ParseFromString(payload_bytes)
d = MessageToDict(msg, preserving_proto_field_name=True)
# 2. Manually extract the true enum value for jobType (field 3)
wire_num = self._extract_field3_varint(payload_bytes)
wire_name = None
if wire_num is not None:
try:
wire_name = jobType_e.Name(wire_num)
except ValueError:
wire_name = f"UNKNOWN_ENUM_VALUE_{wire_num}"
# 3. Prefer the manually extracted value
if wire_name:
d["jobType"] = wire_name
# 4. Ensure consistent structure
d.setdefault("rpcData", None)
d.setdefault("slotInfo", None)
return d
except Exception as e:
print(f"An unexpected error occurred during RPC request decoding: {e}")
return None
# --- Helper methods for manual byte parsing ---
def _read_varint(self, b: bytes, i: int):
"""Helper to read a varint from a raw byte buffer."""
shift = 0
val = 0
while True:
if i >= len(b): raise ValueError("truncated varint")
c = b[i]
i += 1
val |= (c & 0x7F) << shift
if not (c & 0x80): break
shift += 7
if shift > 64: raise ValueError("varint too long")
return val, i
def _skip_field(self, b: bytes, i: int, wt: int):
"""Helper to skip a field in the buffer based on its wire type."""
if wt == 0: # VARINT
_, i = self._read_varint(b, i)
return i
if wt == 1: # 64-BIT
return i + 8
if wt == 2: # LENGTH-DELIMITED
ln, i = self._read_varint(b, i)
return i + ln
if wt == 5: # 32-BIT
return i + 4
raise ValueError(f"unsupported wire type to skip: {wt}")
def _extract_field3_varint(self, b: bytes):
"""Manually parses the byte string to find the integer value of field number 3 (e.g., eventType, jobType)."""
i = 0
n = len(b)
while i < n:
key, i2 = self._read_varint(b, i)
wt = key & 0x7
fn = key >> 3
i = i2
if fn == 3 and wt == 0:
v, _ = self._read_varint(b, i)
return v
i = self._skip_field(b, i, wt)
return None