1# Copyright 2024 The Pigweed Authors 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); you may not 4# use this file except in compliance with the License. You may obtain a copy of 5# the License at 6# 7# https://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 11# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 12# License for the specific language governing permissions and limitations under 13# the License. 14"""Tooling to generate C++ constants from a yaml sensor definition.""" 15 16import argparse 17from dataclasses import dataclass, fields, is_dataclass 18from collections.abc import Sequence 19import io 20import re 21import sys 22from typing import Type, List, Any 23import typing 24 25import yaml 26 27 28def kid_from_name(name: str) -> str: 29 """ 30 Generate a const style ID name from a given name string. Example: 31 If name is "sample_rate", the ID would be kSampleRate 32 33 Args: 34 name: the name to convert to an ID 35 Returns: 36 C++ style 'k' prefixed camel cased ID 37 """ 38 return "k" + ''.join(ele.title() for ele in re.split(r"[\s_\-\,]+", name)) 39 40 41class Printable: 42 """Common printable object""" 43 44 def __init__( 45 self, item_id: str, name: str, description: str | None 46 ) -> None: 47 self.id: str = item_id 48 self.name: str = name 49 self.description: str | None = description 50 51 def __hash__(self) -> int: 52 return hash((self.id, self.name, self.description)) 53 54 @property 55 def variable_name(self) -> str: 56 """Convert the 'id' to a constant variable name in C++.""" 57 return kid_from_name(self.id) 58 59 def print(self, writer: io.TextIOWrapper) -> None: 60 """ 61 Baseclass for a printable sensor object 62 63 Args: 64 writer: IO writer used to print values. 65 """ 66 writer.write( 67 f""" 68/// @brief {self.name} 69""" 70 ) 71 if self.description: 72 writer.write( 73 f"""/// 74/// {self.description} 75""" 76 ) 77 78 79@dataclass 80class UnitsSpec: 81 """Typing for the Units definition dictionary.""" 82 83 name: str 84 symbol: str 85 86 87class Units(Printable): 88 """A single unit representation""" 89 90 symbol: str 91 92 def __init__(self, unit_id: str, definition: UnitsSpec) -> None: 93 """ 94 Create a new unit object 95 96 Args: 97 unit_id: The ID of the unit 98 definition: A dictionary of the unit definition 99 """ 100 super().__init__( 101 item_id=unit_id, 102 name=definition.name, 103 description=definition.name, 104 ) 105 self.symbol: str = definition.symbol 106 107 def __hash__(self) -> int: 108 return hash((super().__hash__(), self.symbol)) 109 110 def __eq__(self, value: object) -> bool: 111 if not isinstance(value, Units): 112 return False 113 return ( 114 self.id == value.id 115 and self.name == value.name 116 and self.description == value.description 117 and self.symbol == value.symbol 118 ) 119 120 def print(self, writer: io.TextIOWrapper) -> None: 121 """ 122 Print header definition for this unit 123 124 Args: 125 writer: IO writer used to print values. 126 """ 127 super().print(writer=writer) 128 writer.write( 129 f"""PW_SENSOR_UNIT_TYPE( 130 {super().variable_name}, 131 "PW_SENSOR_UNITS_TYPE", 132 "{self.name}", 133 "{self.symbol}" 134); 135""" 136 ) 137 138 139@dataclass 140class AttributeSpec: 141 """Typing for the Attribute definition dictionary.""" 142 143 name: str 144 description: str 145 146 147class Attribute(Printable): 148 """A single attribute representation.""" 149 150 def __init__(self, attr_id: str, definition: AttributeSpec) -> None: 151 super().__init__( 152 item_id=attr_id, 153 name=definition.name, 154 description=definition.description, 155 ) 156 157 def print(self, writer: io.TextIOWrapper) -> None: 158 """ 159 Print header definition for this attribute 160 161 Args: 162 writer: IO writer used to print values. 163 """ 164 super().print(writer=writer) 165 writer.write( 166 f"""PW_SENSOR_ATTRIBUTE_TYPE( 167 {super().variable_name}, 168 "PW_SENSOR_ATTRIBUTE_TYPE", 169 "{self.name}" 170); 171""" 172 ) 173 174 175@dataclass 176class ChannelSpec: 177 """Typing for the Channel definition dictionary.""" 178 179 name: str 180 description: str 181 units: str 182 183 184class Channel(Printable): 185 """A single channel representation.""" 186 187 def __init__( 188 self, channel_id: str, definition: ChannelSpec, units: dict[str, Units] 189 ) -> None: 190 super().__init__( 191 item_id=channel_id, 192 name=definition.name, 193 description=definition.description, 194 ) 195 self.units: Units = units[definition.units] 196 197 def __hash__(self) -> int: 198 return hash((super().__hash__(), self.units)) 199 200 def __eq__(self, value: object) -> bool: 201 if not isinstance(value, Channel): 202 return False 203 return ( 204 self.id == value.id 205 and self.name == value.name 206 and self.description == value.description 207 and self.units == value.units 208 ) 209 210 def print(self, writer: io.TextIOWrapper) -> None: 211 """ 212 Print header definition for this channel 213 214 Args: 215 writer: IO writer used to print values. 216 """ 217 super().print(writer=writer) 218 writer.write( 219 f""" 220PW_SENSOR_MEASUREMENT_TYPE({super().variable_name}, 221 "PW_SENSOR_MEASUREMENT_TYPE", 222 "{self.name}", 223 ::pw::sensor::units::{self.units.variable_name} 224);""" 225 ) 226 227 228@dataclass 229class TriggerSpec: 230 """Typing for the Trigger definition dictionary.""" 231 232 name: str 233 description: str 234 235 236class Trigger(Printable): 237 """A single trigger representation.""" 238 239 def __init__(self, trigger_id: str, definition: TriggerSpec) -> None: 240 super().__init__( 241 item_id=trigger_id, 242 name=definition.name, 243 description=definition.description, 244 ) 245 246 def print(self, writer: io.TextIOWrapper) -> None: 247 """ 248 Print header definition for this trigger 249 250 Args: 251 writer: IO writer used to print values. 252 """ 253 super().print(writer=writer) 254 writer.write( 255 f"""PW_SENSOR_TRIGGER_TYPE( 256 {super().variable_name}, 257 "PW_SENSOR_TRIGGER_TYPE", 258 "{self.name}" 259); 260""" 261 ) 262 263 264@dataclass 265class SensorAttributeSpec: 266 """Typing for the SensorAttribute definition dictionary.""" 267 268 channel: str 269 attribute: str 270 units: str 271 272 273class SensorAttribute(Printable): 274 """An attribute instance belonging to a sensor""" 275 276 @staticmethod 277 def id_from_definition(definition: SensorAttributeSpec) -> str: 278 """ 279 Get a unique ID for the channel/attribute pair (not sensor specific) 280 281 Args: 282 definition: A dictionary of the attribute definition 283 Returns: 284 String representation for the channel/attribute pair 285 """ 286 return f"{definition.channel}-{definition.attribute}" 287 288 @staticmethod 289 def name_from_definition(definition: SensorAttributeSpec) -> str: 290 """ 291 Get a unique name for the channel/attribute pair (not sensor specific) 292 293 Args: 294 definition: A dictionary of the attribute definition 295 Returns: 296 String representation of the human readable name for the 297 channel/attribute pair. 298 """ 299 return f"{definition.channel}'s {definition.attribute} attribute" 300 301 @staticmethod 302 def description_from_definition(definition: SensorAttributeSpec) -> str: 303 """ 304 Get the description for the channel/attribute pair (not sensor specific) 305 306 Args: 307 definition: A dictionary of the attribute definition 308 Returns: 309 A description string for the channel/attribute pair. 310 """ 311 return ( 312 f"Allow the configuration of the {definition.channel}'s " 313 + f"{definition.attribute} attribute" 314 ) 315 316 def __init__(self, definition: SensorAttributeSpec) -> None: 317 super().__init__( 318 item_id=SensorAttribute.id_from_definition(definition=definition), 319 name=SensorAttribute.name_from_definition(definition=definition), 320 description=SensorAttribute.description_from_definition( 321 definition=definition 322 ), 323 ) 324 self.attribute: str = definition.attribute 325 self.channel: str = definition.channel 326 self.units: str = definition.units 327 328 def __hash__(self) -> int: 329 return hash( 330 (super().__hash__(), self.attribute, self.channel, self.units) 331 ) 332 333 def print(self, writer: io.TextIOWrapper) -> None: 334 super().print(writer) 335 writer.write( 336 f""" 337PW_SENSOR_ATTRIBUTE_INSTANCE({self.variable_name}, 338 channels::{kid_from_name(self.channel)}, 339 attributes::{kid_from_name(self.attribute)}, 340 units::{kid_from_name(self.units)}); 341""" 342 ) 343 344 345@dataclass 346class CompatibleSpec: 347 """Typing for the Compatible dictionary.""" 348 349 org: str 350 part: str 351 352 353@dataclass 354class SensorSpec: 355 """Typing for the Sensor definition dictionary.""" 356 357 description: str 358 compatible: CompatibleSpec 359 supported_buses: List[str] 360 attributes: List[SensorAttributeSpec] 361 channels: dict[str, List[ChannelSpec]] 362 triggers: List[Any] 363 extras: dict[str, Any] 364 365 366class Sensor(Printable): 367 """Represent a single sensor type instance""" 368 369 @staticmethod 370 def sensor_id_to_name(sensor_id: str) -> str: 371 """ 372 Convert a sensor ID to a human readable name 373 374 Args: 375 sensor_id: The ID of the sensor 376 Returns: 377 Human readable name based on the ID 378 """ 379 return sensor_id.replace(',', ' ') 380 381 def __init__(self, item_id: str, definition: SensorSpec) -> None: 382 super().__init__( 383 item_id=item_id, 384 name=Sensor.sensor_id_to_name(item_id), 385 description=definition.description, 386 ) 387 self.compatible_org: str = definition.compatible.org 388 self.compatible_part: str = definition.compatible.part 389 self.chan_count: int = len(definition.channels) 390 self.attr_count: int = len(definition.attributes) 391 self.trig_count: int = len(definition.triggers) 392 self.attributes: Sequence[SensorAttribute] = [ 393 SensorAttribute(definition=spec) for spec in definition.attributes 394 ] 395 396 @property 397 def namespace(self) -> str: 398 """ 399 The namespace which owns the sensor (the org from the compatible) 400 401 Returns: 402 The C++ style namespace name of the org. 403 """ 404 return self.compatible_org.replace("-", "_") 405 406 def __hash__(self) -> int: 407 return hash( 408 ( 409 super().__hash__(), 410 self.compatible_org, 411 self.compatible_part, 412 self.chan_count, 413 self.attr_count, 414 self.trig_count, 415 ) 416 ) 417 418 def print(self, writer: io.TextIOWrapper) -> None: 419 """ 420 Print header definition for this trigger 421 422 Args: 423 writer: IO writer used to print values. 424 """ 425 writer.write( 426 f""" 427namespace {self.namespace} {{ 428class {self.compatible_part.upper()} 429 : public pw::sensor::zephyr::ZephyrSensor<{len(self.attributes)}> {{ 430 public: 431 {self.compatible_part.upper()}(const struct device* dev) 432 : pw::sensor::zephyr::ZephyrSensor<{len(self.attributes)}>( 433 dev, 434 {{""" 435 ) 436 for attribute in self.attributes: 437 writer.write( 438 f""" 439 Attribute::Build<{attribute.variable_name}>(),""" 440 ) 441 writer.write( 442 f""" 443 }}) {{}} 444}}; 445}} // namespace {self.namespace} 446""" 447 ) 448 449 450@dataclass 451class Args: 452 """CLI arguments""" 453 454 package: Sequence[str] 455 language: str 456 zephyr: bool 457 458 459class CppHeader: 460 """Generator for a C++ header""" 461 462 def __init__( 463 self, 464 using_zephyr: bool, 465 package: Sequence[str], 466 attributes: Sequence[Attribute], 467 channels: Sequence[Channel], 468 triggers: Sequence[Trigger], 469 units: Sequence[Units], 470 sensors: Sequence[Sensor], 471 ) -> None: 472 """ 473 Args: 474 package: The package name used in the output. In C++ we'll convert 475 this to a namespace. 476 attributes: A sequence of attributes which will be exposed in the 477 'attributes' namespace 478 channels: A sequence of channels which will be exposed in the 479 'channels' namespace 480 triggers: A sequence of triggers which will be exposed in the 481 'triggers' namespace 482 units: A sequence of units which should be exposed in the 'units' 483 namespace 484 """ 485 self._using_zephyr: bool = using_zephyr 486 self._package: str = '::'.join(package) 487 self._attributes: Sequence[Attribute] = attributes 488 self._channels: Sequence[Channel] = channels 489 self._triggers: Sequence[Trigger] = triggers 490 self._units: Sequence[Units] = units 491 self._sensors: Sequence[Sensor] = sensors 492 self._sensor_attributes: set[SensorAttribute] = set() 493 for sensor in sensors: 494 self._sensor_attributes.update(sensor.attributes) 495 496 def __str__(self) -> str: 497 writer = io.StringIO() 498 self._print_header(writer=writer) 499 self._print_constants(writer=writer) 500 self._print_footer(writer=writer) 501 return writer.getvalue() 502 503 def _print_header(self, writer: io.TextIOWrapper) -> None: 504 """ 505 Print the top part of the .h file (pragma, includes, and namespace) 506 507 Args: 508 writer: Where to write the text to 509 """ 510 writer.write( 511 "/* Auto-generated file, do not edit */\n" 512 "#pragma once\n" 513 "\n" 514 "#include \"pw_sensor/types.h\"\n" 515 ) 516 if self._package: 517 writer.write(f"namespace {self._package} {{\n\n") 518 519 def _print_constants(self, writer: io.TextIOWrapper) -> None: 520 """ 521 Print the constants definitions from self._attributes, self._channels, 522 and self._trigger 523 524 Args: 525 writer: Where to write the text 526 """ 527 528 self._print_in_namespace( 529 namespace="units", printables=self._units, writer=writer 530 ) 531 self._print_in_namespace( 532 namespace="attributes", printables=self._attributes, writer=writer 533 ) 534 self._print_in_namespace( 535 namespace="channels", printables=self._channels, writer=writer 536 ) 537 self._print_in_namespace( 538 namespace="triggers", printables=self._triggers, writer=writer 539 ) 540 for sensor_attribute in self._sensor_attributes: 541 sensor_attribute.print(writer=writer) 542 543 @staticmethod 544 def _print_in_namespace( 545 namespace: str, 546 printables: Sequence[Printable], 547 writer: io.TextIOWrapper, 548 ) -> None: 549 """ 550 Print constants definitions wrapped in a namespace 551 Args: 552 namespace: The namespace to use 553 printables: A sequence of printable objects 554 writer: Where to write the text 555 """ 556 writer.write(f"\nnamespace {namespace} {{\n") 557 for printable in printables: 558 printable.print(writer=writer) 559 writer.write(f"\n}} // namespace {namespace}\n") 560 561 def _print_footer(self, writer: io.TextIOWrapper) -> None: 562 """ 563 Write the bottom part of the .h file (closing namespace) 564 565 Args: 566 writer: Where to write the text 567 """ 568 if self._package: 569 writer.write(f"\n}} // namespace {self._package}") 570 571 if self._using_zephyr: 572 self._print_zephyr_mapping(writer=writer) 573 574 def _print_zephyr_mapping(self, writer: io.TextIOWrapper) -> None: 575 """ 576 Generate Zephyr type maps for channels, attributes, and triggers. 577 578 Args: 579 writer: Where to write the text 580 """ 581 writer.write( 582 f""" 583#include <zephyr/generated/sensor_constants.h> 584#include \"pw_containers/flat_map.h\" 585 586namespace pw::sensor::zephyr {{ 587 588class ZephyrAttributeMap 589 : public pw::containers::FlatMap<uint32_t, uint32_t, 590 {len(self._attributes)}> {{ 591 public: 592 constexpr ZephyrAttributeMap() 593 : pw::containers::FlatMap<uint32_t, uint32_t, 594 {len(self._attributes)}>({{{{""" 595 ) 596 for attribute in self._attributes: 597 attribute_type = ( 598 f"{self._package}::attributes::" 599 + f"{attribute.variable_name}::kAttributeType" 600 ) 601 writer.write( 602 f""" 603 {{{attribute_type}, 604 SENSOR_ATTR_{attribute.id.upper()}}},""" 605 ) 606 writer.write("\n }}) {}\n};") 607 writer.write( 608 f""" 609 610class ZephyrChannelMap 611 : public pw::containers::FlatMap<uint32_t, uint32_t, 612 {len(self._channels)}> {{ 613 public: 614 constexpr ZephyrChannelMap() 615 : pw::containers::FlatMap<uint32_t, uint32_t, 616 {len(self._channels)}>({{{{""" 617 ) 618 for channel in self._channels: 619 measurement_name = ( 620 f"{self._package}::channels::" 621 + f"{channel.variable_name}::kMeasurementName" 622 ) 623 writer.write( 624 f""" 625 {{{measurement_name}, 626 SENSOR_CHAN_{channel.id.upper()}}},""" 627 ) 628 writer.write( 629 """ 630 }}) {} 631}; 632 633extern ZephyrAttributeMap kAttributeMap; 634extern ZephyrChannelMap kChannelMap; 635 636} // namespace pw::sensor::zephyr 637 638#include "pw_sensor_zephyr/sensor.h" 639namespace pw::sensor { 640""" 641 ) 642 for sensor in self._sensors: 643 sensor.print(writer=writer) 644 writer.write( 645 """ 646} // namespace pw::sensor 647""" 648 ) 649 650 651@dataclass 652class InputSpec: 653 """Typing for the InputData spec dictionary""" 654 655 units: dict[str, UnitsSpec] 656 attributes: dict[str, AttributeSpec] 657 channels: dict[str, ChannelSpec] 658 triggers: dict[str, TriggerSpec] 659 sensors: dict[str, SensorSpec] 660 661 662class InputData: 663 """ 664 Wrapper class for all the input data parsed out into: Units, Attribute, 665 Channel, Trigger, and Sensor types (or sub-types). 666 """ 667 668 def __init__( 669 self, 670 spec: InputSpec, 671 units_type: Type[Units] = Units, 672 attribute_type: Type[Attribute] = Attribute, 673 channel_type: Type[Channel] = Channel, 674 trigger_type: Type[Trigger] = Trigger, 675 sensor_type: Type[Sensor] = Sensor, 676 ) -> None: 677 """ 678 Parse the input spec and create all the input data types. 679 680 Args: 681 spec: The input spec dictionary 682 units_type: The type to use for units 683 attribute_type: The type to use for attributes 684 channel_type: The type to use for channels 685 trigger_type: The type to use for triggers 686 sensor_type: The type to use for sensors 687 """ 688 self.all_attributes: set[Attribute] = set() 689 self.all_channels: set[Channel] = set() 690 self.all_triggers: set[Trigger] = set() 691 self.all_units: dict[str, Units] = {} 692 self.all_sensors: set[Sensor] = set() 693 for units_id, units_spec in spec.units.items(): 694 units = units_type(unit_id=units_id, definition=units_spec) 695 assert units not in self.all_units.values() 696 self.all_units[units_id] = units 697 for attribute_id, attribute_spec in spec.attributes.items(): 698 attribute = attribute_type( 699 attr_id=attribute_id, definition=attribute_spec 700 ) 701 assert attribute not in self.all_attributes 702 self.all_attributes.add(attribute) 703 for channel_id, channel_spec in spec.channels.items(): 704 channel = channel_type( 705 channel_id=channel_id, 706 definition=channel_spec, 707 units=self.all_units, 708 ) 709 assert channel not in self.all_channels 710 self.all_channels.add(channel) 711 for trigger_id, trigger_spec in spec.triggers.items(): 712 trigger = trigger_type( 713 trigger_id=trigger_id, definition=trigger_spec 714 ) 715 assert trigger not in self.all_triggers 716 self.all_triggers.add(trigger) 717 for sensor_id, sensor_spec in spec.sensors.items(): 718 sensor = sensor_type(item_id=sensor_id, definition=sensor_spec) 719 assert sensor not in self.all_sensors 720 self.all_sensors.add(sensor) 721 722 723def is_list_type(t) -> bool: 724 """ 725 Checks if the given type `t` is either a list or typing.List. 726 727 Args: 728 t: The type to check. 729 Returns: 730 True if `t` is a list type, False otherwise. 731 """ 732 origin = typing.get_origin(t) 733 return origin is list or (origin is list and typing.get_args(t) == ()) 734 735 736def is_primitive(value): 737 """Checks if the given value is of a primitive type. 738 739 Args: 740 value: The value to check. 741 742 Returns: 743 True if the value is of a primitive type, False otherwise. 744 """ 745 return isinstance(value, (int, float, complex, str, bool)) 746 747 748def create_dataclass_from_dict(cls, data, indent: int = 0): 749 """Recursively creates a dataclass instance from a nested dictionary.""" 750 751 field_values = {} 752 753 if is_list_type(cls): 754 result = [] 755 for item in data: 756 result.append( 757 create_dataclass_from_dict( 758 typing.get_args(cls)[0], item, indent + 2 759 ) 760 ) 761 return result 762 763 if is_primitive(data): 764 return data 765 766 for field in fields(cls): 767 field_value = data.get(field.name) 768 if field_value is None: 769 field_value = data.get(field.name.replace('_', '-')) 770 771 assert field_value is not None 772 773 # We need to check if the field is a List, dictionary, or another 774 # dataclass. If it is, recurse. 775 if is_list_type(field.type): 776 item_type = typing.get_args(field.type)[0] 777 field_value = [ 778 create_dataclass_from_dict(item_type, item, indent + 2) 779 for item in field_value 780 ] 781 elif dict in field.type.__mro__: 782 # We might not have types specified in the dataclass 783 value_types = typing.get_args(field.type) 784 if len(value_types) != 0: 785 value_type = value_types[1] 786 field_value = { 787 key: create_dataclass_from_dict(value_type, val, indent + 2) 788 for key, val in field_value.items() 789 } 790 elif is_dataclass(field.type): 791 field_value = create_dataclass_from_dict( 792 field.type, field_value, indent + 2 793 ) 794 795 field_values[field.name] = field_value 796 797 return cls(**field_values) 798 799 800def main() -> None: 801 """ 802 Main entry point, this function will: 803 - Get CLI flags 804 - Read YAML from stdin 805 - Find all attribute, channel, trigger, and unit definitions 806 - Print header 807 """ 808 args = get_args() 809 yaml_input = yaml.safe_load(sys.stdin) 810 spec: InputSpec = create_dataclass_from_dict(InputSpec, yaml_input) 811 data = InputData(spec=spec) 812 813 if args.language == "cpp": 814 out = CppHeader( 815 using_zephyr=args.zephyr, 816 package=args.package, 817 attributes=list(data.all_attributes), 818 channels=list(data.all_channels), 819 triggers=list(data.all_triggers), 820 units=list(data.all_units.values()), 821 sensors=list(data.all_sensors), 822 ) 823 else: 824 raise ValueError(f"Invalid language selected: '{args.language}'") 825 print(out) 826 827 828def validate_package_arg(value: str) -> str: 829 """ 830 Validate that the package argument is a valid string 831 832 Args: 833 value: The package name 834 835 Returns: 836 The same value after being validated. 837 """ 838 if value is None or value == "": 839 return value 840 if not re.match(r"[a-zA-Z_$][\w$]*(\.[a-zA-Z_$][\w$]*)*", value): 841 raise argparse.ArgumentError( 842 argument=None, 843 message=f"Invalid string {value}. Must use alphanumeric values " 844 "separated by dots.", 845 ) 846 return value 847 848 849def get_args() -> Args: 850 """ 851 Get CLI arguments 852 853 Returns: 854 Typed arguments class instance 855 """ 856 parser = argparse.ArgumentParser() 857 parser.add_argument( 858 "--package", 859 "-pkg", 860 default="", 861 type=validate_package_arg, 862 help="Output package name separated by '.', example: com.google", 863 ) 864 parser.add_argument( 865 "--language", 866 type=str, 867 choices=["cpp"], 868 default="cpp", 869 ) 870 parser.add_argument( 871 "--zephyr", 872 action="store_true", 873 ) 874 args = parser.parse_args() 875 return Args( 876 package=args.package.split("."), 877 language=args.language, 878 zephyr=args.zephyr, 879 ) 880 881 882if __name__ == "__main__": 883 main() 884