xref: /aosp_15_r20/external/pigweed/pw_sensor/py/pw_sensor/constants_generator.py (revision 61c4878ac05f98d0ceed94b57d316916de578985)
1# Copyright 2024 The Pigweed Authors
2#
3# Licensed under the Apache License, Version 2.0 (the "License"); you may not
4# use this file except in compliance with the License. You may obtain a copy of
5# the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12# License for the specific language governing permissions and limitations under
13# the License.
14"""Tooling to generate C++ constants from a yaml sensor definition."""
15
16import argparse
17from dataclasses import dataclass, fields, is_dataclass
18from collections.abc import Sequence
19import io
20import re
21import sys
22from typing import Type, List, Any
23import typing
24
25import yaml
26
27
def kid_from_name(name: str) -> str:
    """
    Build a C++ constant identifier ('k' + CamelCase) from a name.

    The name is split on runs of whitespace, '_', '-' and ',' and every
    piece is passed through ``str.title()``. For example "sample_rate"
    becomes "kSampleRate" and "acceleration-sample_rate" becomes
    "kAccelerationSampleRate".

    Args:
      name: the name to convert to an ID
    Returns:
      C++ style 'k' prefixed camel cased ID
    """
    words = re.split(r"[\s_\-\,]+", name)
    return "k" + "".join(map(str.title, words))
39
40
41class Printable:
42    """Common printable object"""
43
44    def __init__(
45        self, item_id: str, name: str, description: str | None
46    ) -> None:
47        self.id: str = item_id
48        self.name: str = name
49        self.description: str | None = description
50
51    def __hash__(self) -> int:
52        return hash((self.id, self.name, self.description))
53
54    @property
55    def variable_name(self) -> str:
56        """Convert the 'id' to a constant variable name in C++."""
57        return kid_from_name(self.id)
58
59    def print(self, writer: io.TextIOWrapper) -> None:
60        """
61        Baseclass for a printable sensor object
62
63        Args:
64          writer: IO writer used to print values.
65        """
66        writer.write(
67            f"""
68/// @brief {self.name}
69"""
70        )
71        if self.description:
72            writer.write(
73                f"""///
74/// {self.description}
75"""
76            )
77
78
@dataclass
class UnitsSpec:
    """Schema of one entry in the YAML ``units`` map."""

    # Human readable name of the unit.
    name: str
    # Symbol emitted into the PW_SENSOR_UNIT_TYPE definition.
    symbol: str
85
86
class Units(Printable):
    """A single unit of measure, printed as a PW_SENSOR_UNIT_TYPE."""

    symbol: str

    def __init__(self, unit_id: str, definition: UnitsSpec) -> None:
        """
        Create a new unit object

        Args:
          unit_id: The ID of the unit
          definition: The parsed unit specification
        """
        # UnitsSpec has no separate description field, so the display name
        # is reused as the description.
        display_name = definition.name
        super().__init__(
            item_id=unit_id,
            name=display_name,
            description=display_name,
        )
        self.symbol: str = definition.symbol

    def __hash__(self) -> int:
        base_hash = super().__hash__()
        return hash((base_hash, self.symbol))

    def __eq__(self, value: object) -> bool:
        if not isinstance(value, Units):
            return False
        mine = (self.id, self.name, self.description, self.symbol)
        theirs = (value.id, value.name, value.description, value.symbol)
        return mine == theirs

    def print(self, writer: io.TextIOWrapper) -> None:
        """
        Print the doc header followed by the PW_SENSOR_UNIT_TYPE macro.

        Args:
          writer: IO writer used to print values.
        """
        super().print(writer=writer)
        macro = (
            "PW_SENSOR_UNIT_TYPE(\n"
            f"    {super().variable_name},\n"
            '    "PW_SENSOR_UNITS_TYPE",\n'
            f'    "{self.name}",\n'
            f'    "{self.symbol}"\n'
            ");\n"
        )
        writer.write(macro)
137
138
@dataclass
class AttributeSpec:
    """Schema of one entry in the YAML ``attributes`` map."""

    # Human readable name of the attribute.
    name: str
    # Longer description, emitted as a doc comment.
    description: str
145
146
class Attribute(Printable):
    """A configurable attribute, printed as a PW_SENSOR_ATTRIBUTE_TYPE."""

    def __init__(self, attr_id: str, definition: AttributeSpec) -> None:
        """
        Args:
          attr_id: The attribute's ID.
          definition: The parsed attribute specification.
        """
        name = definition.name
        description = definition.description
        super().__init__(item_id=attr_id, name=name, description=description)

    def print(self, writer: io.TextIOWrapper) -> None:
        """
        Print the doc header followed by the PW_SENSOR_ATTRIBUTE_TYPE macro.

        Args:
          writer: IO writer used to print values.
        """
        super().print(writer=writer)
        macro = (
            "PW_SENSOR_ATTRIBUTE_TYPE(\n"
            f"    {super().variable_name},\n"
            '    "PW_SENSOR_ATTRIBUTE_TYPE",\n'
            f'    "{self.name}"\n'
            ");\n"
        )
        writer.write(macro)
173
174
@dataclass
class ChannelSpec:
    """Schema of one entry in the YAML ``channels`` map."""

    # Human readable name of the channel.
    name: str
    # Longer description, emitted as a doc comment.
    description: str
    # ID of the units (a key of the ``units`` map) the channel reports in.
    units: str
182
183
class Channel(Printable):
    """A measurement channel together with the units it is reported in."""

    def __init__(
        self, channel_id: str, definition: ChannelSpec, units: dict[str, Units]
    ) -> None:
        """
        Args:
          channel_id: The channel's ID.
          definition: The parsed channel specification.
          units: All known units keyed by ID; ``definition.units`` is looked
            up here (an unknown ID raises KeyError).
        """
        super().__init__(
            item_id=channel_id,
            name=definition.name,
            description=definition.description,
        )
        self.units: Units = units[definition.units]

    def __hash__(self) -> int:
        base_hash = super().__hash__()
        return hash((base_hash, self.units))

    def __eq__(self, value: object) -> bool:
        if not isinstance(value, Channel):
            return False
        mine = (self.id, self.name, self.description, self.units)
        theirs = (value.id, value.name, value.description, value.units)
        return mine == theirs

    def print(self, writer: io.TextIOWrapper) -> None:
        """
        Print the doc header followed by the PW_SENSOR_MEASUREMENT_TYPE macro.

        Note the macro text starts with a newline and, unlike the other
        printables, does not end with one.

        Args:
          writer: IO writer used to print values.
        """
        super().print(writer=writer)
        units_name = self.units.variable_name
        writer.write(
            f"""
PW_SENSOR_MEASUREMENT_TYPE({super().variable_name},
                           "PW_SENSOR_MEASUREMENT_TYPE",
                           "{self.name}",
                           ::pw::sensor::units::{units_name}
);"""
        )
226
227
@dataclass
class TriggerSpec:
    """Schema of one entry in the YAML ``triggers`` map."""

    # Human readable name of the trigger.
    name: str
    # Longer description, emitted as a doc comment.
    description: str
234
235
class Trigger(Printable):
    """A sensor trigger, printed as a PW_SENSOR_TRIGGER_TYPE."""

    def __init__(self, trigger_id: str, definition: TriggerSpec) -> None:
        """
        Args:
          trigger_id: The trigger's ID.
          definition: The parsed trigger specification.
        """
        name = definition.name
        description = definition.description
        super().__init__(
            item_id=trigger_id, name=name, description=description
        )

    def print(self, writer: io.TextIOWrapper) -> None:
        """
        Print the doc header followed by the PW_SENSOR_TRIGGER_TYPE macro.

        Args:
          writer: IO writer used to print values.
        """
        super().print(writer=writer)
        macro = (
            "PW_SENSOR_TRIGGER_TYPE(\n"
            f"    {super().variable_name},\n"
            '    "PW_SENSOR_TRIGGER_TYPE",\n'
            f'    "{self.name}"\n'
            ");\n"
        )
        writer.write(macro)
262
263
@dataclass
class SensorAttributeSpec:
    """Schema of one entry in a sensor's ``attributes`` list."""

    # ID of the channel the attribute applies to.
    channel: str
    # ID of the attribute being exposed.
    attribute: str
    # ID of the units the attribute value is expressed in.
    units: str
271
272
class SensorAttribute(Printable):
    """
    An attribute instance belonging to a sensor.

    Instances compare by value (like Units and Channel), so identical
    channel/attribute/units triples coming from different sensors collapse
    into one entry when stored in a set.
    """

    @staticmethod
    def id_from_definition(definition: SensorAttributeSpec) -> str:
        """
        Get a unique ID for the channel/attribute pair (not sensor specific)

        Args:
          definition: A dictionary of the attribute definition
        Returns:
          String representation for the channel/attribute pair
        """
        return f"{definition.channel}-{definition.attribute}"

    @staticmethod
    def name_from_definition(definition: SensorAttributeSpec) -> str:
        """
        Get a unique name for the channel/attribute pair (not sensor specific)

        Args:
          definition: A dictionary of the attribute definition
        Returns:
          String representation of the human readable name for the
          channel/attribute pair.
        """
        return f"{definition.channel}'s {definition.attribute} attribute"

    @staticmethod
    def description_from_definition(definition: SensorAttributeSpec) -> str:
        """
        Get the description for the channel/attribute pair (not sensor specific)

        Args:
          definition: A dictionary of the attribute definition
        Returns:
          A description string for the channel/attribute pair.
        """
        return (
            f"Allow the configuration of the {definition.channel}'s "
            + f"{definition.attribute} attribute"
        )

    def __init__(self, definition: SensorAttributeSpec) -> None:
        super().__init__(
            item_id=SensorAttribute.id_from_definition(definition=definition),
            name=SensorAttribute.name_from_definition(definition=definition),
            description=SensorAttribute.description_from_definition(
                definition=definition
            ),
        )
        self.attribute: str = definition.attribute
        self.channel: str = definition.channel
        self.units: str = definition.units

    def __hash__(self) -> int:
        return hash(
            (super().__hash__(), self.attribute, self.channel, self.units)
        )

    def __eq__(self, value: object) -> bool:
        # Without this, equality falls back to identity and two sensors
        # sharing the same channel/attribute pair would both be kept in a
        # set, producing duplicate C++ definitions. The compared fields are
        # exactly those that feed __hash__, keeping the two consistent.
        if not isinstance(value, SensorAttribute):
            return False
        return (
            self.id == value.id
            and self.name == value.name
            and self.description == value.description
            and self.attribute == value.attribute
            and self.channel == value.channel
            and self.units == value.units
        )

    def print(self, writer: io.TextIOWrapper) -> None:
        """
        Print the doc header followed by a PW_SENSOR_ATTRIBUTE_INSTANCE macro
        referencing the channel, attribute, and units constants.

        Args:
          writer: IO writer used to print values.
        """
        super().print(writer)
        writer.write(
            f"""
PW_SENSOR_ATTRIBUTE_INSTANCE({self.variable_name},
                             channels::{kid_from_name(self.channel)},
                             attributes::{kid_from_name(self.attribute)},
                             units::{kid_from_name(self.units)});
"""
        )
343
344
@dataclass
class CompatibleSpec:
    """Schema of a sensor's ``compatible`` entry."""

    # Organization; becomes the C++ namespace ('-' replaced by '_').
    org: str
    # Part name; upper-cased to form the generated C++ class name.
    part: str
351
352
@dataclass
class SensorSpec:
    """Schema of one entry in the YAML ``sensors`` map."""

    # Free-form description of the sensor.
    description: str
    # Org/part pair identifying the sensor (see CompatibleSpec).
    compatible: CompatibleSpec
    # Bus names the sensor supports; not used by the Sensor class.
    supported_buses: List[str]
    # Channel/attribute/units triples the sensor exposes.
    attributes: List[SensorAttributeSpec]
    # Presumably keyed by channel ID; the Sensor class only counts entries.
    channels: dict[str, List[ChannelSpec]]
    # Triggers the sensor supports; the Sensor class only counts entries.
    triggers: List[Any]
    # Arbitrary extra data; not interpreted by the Sensor class.
    extras: dict[str, Any]
364
365
class Sensor(Printable):
    """A sensor type from the YAML ``sensors`` map."""

    @staticmethod
    def sensor_id_to_name(sensor_id: str) -> str:
        """
        Derive a human readable name from a sensor ID.

        Args:
          sensor_id: The ID of the sensor
        Returns:
          The ID with every ',' replaced by a space.
        """
        return sensor_id.replace(',', ' ')

    def __init__(self, item_id: str, definition: SensorSpec) -> None:
        """
        Args:
          item_id: The sensor's ID.
          definition: The parsed sensor specification.
        """
        super().__init__(
            item_id=item_id,
            name=Sensor.sensor_id_to_name(item_id),
            description=definition.description,
        )
        compatible = definition.compatible
        self.compatible_org: str = compatible.org
        self.compatible_part: str = compatible.part
        # Channels and triggers are only counted; attributes are also
        # materialized as SensorAttribute objects below.
        self.chan_count: int = len(definition.channels)
        self.attr_count: int = len(definition.attributes)
        self.trig_count: int = len(definition.triggers)
        self.attributes: Sequence[SensorAttribute] = [
            SensorAttribute(definition=attribute_spec)
            for attribute_spec in definition.attributes
        ]

    @property
    def namespace(self) -> str:
        """
        The namespace which owns the sensor (the org from the compatible)

        Returns:
          The compatible org with '-' replaced by '_'.
        """
        return self.compatible_org.replace("-", "_")

    def __hash__(self) -> int:
        key = (
            super().__hash__(),
            self.compatible_org,
            self.compatible_part,
            self.chan_count,
            self.attr_count,
            self.trig_count,
        )
        return hash(key)

    def print(self, writer: io.TextIOWrapper) -> None:
        """
        Print a Zephyr driver class for this sensor.

        The class is named after the upper-cased compatible part, lives in
        the org namespace, derives from ZephyrSensor<N> (N = number of
        attributes), and builds one Attribute per sensor attribute. Unlike
        the other printables, no doc-comment header is written.

        Args:
          writer: IO writer used to print values.
        """
        class_name = self.compatible_part.upper()
        attribute_count = len(self.attributes)
        writer.write(
            f"""
namespace {self.namespace} {{
class {class_name}
  : public pw::sensor::zephyr::ZephyrSensor<{attribute_count}> {{
 public:
  {class_name}(const struct device* dev)
      : pw::sensor::zephyr::ZephyrSensor<{attribute_count}>(
            dev,
            {{"""
        )
        for sensor_attribute in self.attributes:
            writer.write(
                f"""
             Attribute::Build<{sensor_attribute.variable_name}>(),"""
            )
        writer.write(
            f"""
            }}) {{}}
}};
}}  // namespace {self.namespace}
"""
        )
448
449
@dataclass
class Args:
    """Typed view of the parsed command-line flags."""

    # Package components: the --package value split on '.'.
    package: Sequence[str]
    # Output language; the CLI only accepts "cpp".
    language: str
    # Whether --zephyr was passed.
    zephyr: bool
457
458
class CppHeader:
    """Generator for a C++ header"""

    def __init__(
        self,
        using_zephyr: bool,
        package: Sequence[str],
        attributes: Sequence[Attribute],
        channels: Sequence[Channel],
        triggers: Sequence[Trigger],
        units: Sequence[Units],
        sensors: Sequence[Sensor],
    ) -> None:
        """
        Args:
          using_zephyr: When True, the Zephyr attribute/channel maps and the
            per-sensor driver classes are appended after the constants.
          package: The package name used in the output. In C++ we'll convert
            this to a namespace.
          attributes: A sequence of attributes which will be exposed in the
            'attributes' namespace
          channels: A sequence of channels which will be exposed in the
            'channels' namespace
          triggers: A sequence of triggers which will be exposed in the
            'triggers' namespace
          units: A sequence of units which should be exposed in the 'units'
            namespace
          sensors: The sensors to generate. Their attribute instances are
            emitted once per unique channel/attribute ID, sorted by ID; the
            sensor classes themselves are only printed when using_zephyr is
            set.

        Raises:
          ValueError: If two sensors declare the same channel/attribute pair
            with different units (this would otherwise emit two conflicting
            C++ definitions with the same name).
        """
        self._using_zephyr: bool = using_zephyr
        self._package: str = '::'.join(package)
        self._attributes: Sequence[Attribute] = attributes
        self._channels: Sequence[Channel] = channels
        self._triggers: Sequence[Trigger] = triggers
        self._units: Sequence[Units] = units
        self._sensors: Sequence[Sensor] = sensors
        self._sensor_attributes: list[SensorAttribute] = (
            self._unique_sensor_attributes(sensors)
        )

    @staticmethod
    def _unique_sensor_attributes(
        sensors: Sequence[Sensor],
    ) -> list[SensorAttribute]:
        """
        Collect the sensors' attribute instances, one per ID, sorted by ID.

        Deduplicating by ID (instead of relying on object equality/hashing)
        guarantees a channel/attribute pair shared by several sensors is
        printed once, and sorting makes the generated header independent of
        hash ordering, which varies between interpreter runs.

        Raises:
          ValueError: If the same ID is declared with different units.
        """
        by_id: dict[str, SensorAttribute] = {}
        for sensor in sensors:
            for sensor_attribute in sensor.attributes:
                existing = by_id.setdefault(
                    sensor_attribute.id, sensor_attribute
                )
                if existing.units != sensor_attribute.units:
                    raise ValueError(
                        "Conflicting units for sensor attribute "
                        f"'{sensor_attribute.id}': '{existing.units}' vs "
                        f"'{sensor_attribute.units}'"
                    )
        return sorted(by_id.values(), key=lambda attr: attr.id)

    def __str__(self) -> str:
        writer = io.StringIO()
        self._print_header(writer=writer)
        self._print_constants(writer=writer)
        self._print_footer(writer=writer)
        return writer.getvalue()

    def _print_header(self, writer: io.TextIOWrapper) -> None:
        """
        Print the top part of the .h file (pragma, includes, and namespace)

        Args:
          writer: Where to write the text to
        """
        writer.write(
            "/* Auto-generated file, do not edit */\n"
            "#pragma once\n"
            "\n"
            "#include \"pw_sensor/types.h\"\n"
        )
        if self._package:
            writer.write(f"namespace {self._package} {{\n\n")

    def _print_constants(self, writer: io.TextIOWrapper) -> None:
        """
        Print the units, attributes, channels, and triggers (each in its own
        namespace), followed by the de-duplicated sensor attribute instances.

        Args:
            writer: Where to write the text
        """

        self._print_in_namespace(
            namespace="units", printables=self._units, writer=writer
        )
        self._print_in_namespace(
            namespace="attributes", printables=self._attributes, writer=writer
        )
        self._print_in_namespace(
            namespace="channels", printables=self._channels, writer=writer
        )
        self._print_in_namespace(
            namespace="triggers", printables=self._triggers, writer=writer
        )
        # Already unique and sorted by ID (see _unique_sensor_attributes).
        for sensor_attribute in self._sensor_attributes:
            sensor_attribute.print(writer=writer)

    @staticmethod
    def _print_in_namespace(
        namespace: str,
        printables: Sequence[Printable],
        writer: io.TextIOWrapper,
    ) -> None:
        """
        Print constants definitions wrapped in a namespace
        Args:
          namespace: The namespace to use
          printables: A sequence of printable objects
          writer: Where to write the text
        """
        writer.write(f"\nnamespace {namespace} {{\n")
        for printable in printables:
            printable.print(writer=writer)
        writer.write(f"\n}}  // namespace {namespace}\n")

    def _print_footer(self, writer: io.TextIOWrapper) -> None:
        """
        Write the bottom part of the .h file (closing namespace), plus the
        Zephyr mapping section when using_zephyr was set.

        Args:
            writer: Where to write the text
        """
        if self._package:
            writer.write(f"\n}}  // namespace {self._package}")

        if self._using_zephyr:
            self._print_zephyr_mapping(writer=writer)

    def _print_zephyr_mapping(self, writer: io.TextIOWrapper) -> None:
        """
        Generate the Zephyr attribute and channel maps, followed by the
        per-sensor driver classes.

        Args:
            writer: Where to write the text
        """
        writer.write(
            f"""
#include <zephyr/generated/sensor_constants.h>
#include \"pw_containers/flat_map.h\"

namespace pw::sensor::zephyr {{

class ZephyrAttributeMap
    : public pw::containers::FlatMap<uint32_t, uint32_t,
                                     {len(self._attributes)}> {{
 public:
  constexpr ZephyrAttributeMap()
      : pw::containers::FlatMap<uint32_t, uint32_t,
                                {len(self._attributes)}>({{{{"""
        )
        for attribute in self._attributes:
            attribute_type = (
                f"{self._package}::attributes::"
                + f"{attribute.variable_name}::kAttributeType"
            )
            writer.write(
                f"""
            {{{attribute_type},
             SENSOR_ATTR_{attribute.id.upper()}}},"""
            )
        # Plain (non-f) string: the doubled braces are emitted literally to
        # close the "({{" opened above.
        writer.write("\n      }}) {}\n};")
        writer.write(
            f"""

class ZephyrChannelMap
    : public pw::containers::FlatMap<uint32_t, uint32_t,
                                     {len(self._channels)}> {{
 public:
  constexpr ZephyrChannelMap()
      : pw::containers::FlatMap<uint32_t, uint32_t,
                                {len(self._channels)}>({{{{"""
        )
        for channel in self._channels:
            measurement_name = (
                f"{self._package}::channels::"
                + f"{channel.variable_name}::kMeasurementName"
            )
            writer.write(
                f"""
            {{{measurement_name},
             SENSOR_CHAN_{channel.id.upper()}}},"""
            )
        writer.write(
            """
      }}) {}
};

extern ZephyrAttributeMap kAttributeMap;
extern ZephyrChannelMap kChannelMap;

}  // namespace pw::sensor::zephyr

#include "pw_sensor_zephyr/sensor.h"
namespace pw::sensor {
"""
        )
        for sensor in self._sensors:
            sensor.print(writer=writer)
        writer.write(
            """
}  // namespace pw::sensor
"""
        )
649
650
@dataclass
class InputSpec:
    """Schema of the whole YAML input; every map is keyed by object ID."""

    # Unit ID -> unit definition.
    units: dict[str, UnitsSpec]
    # Attribute ID -> attribute definition.
    attributes: dict[str, AttributeSpec]
    # Channel ID -> channel definition.
    channels: dict[str, ChannelSpec]
    # Trigger ID -> trigger definition.
    triggers: dict[str, TriggerSpec]
    # Sensor ID -> sensor definition.
    sensors: dict[str, SensorSpec]
660
661
class InputData:
    """
    Everything parsed out of an InputSpec: Units, Attribute, Channel, Trigger,
    and Sensor objects (or the caller-supplied sub-types of those).
    """

    def __init__(
        self,
        spec: InputSpec,
        units_type: Type[Units] = Units,
        attribute_type: Type[Attribute] = Attribute,
        channel_type: Type[Channel] = Channel,
        trigger_type: Type[Trigger] = Trigger,
        sensor_type: Type[Sensor] = Sensor,
    ) -> None:
        """
        Build every object described by ``spec``.

        Units are created first because channels resolve their units by ID.

        Args:
          spec: The parsed input spec
          units_type: The type to use for units
          attribute_type: The type to use for attributes
          channel_type: The type to use for channels
          trigger_type: The type to use for triggers
          sensor_type: The type to use for sensors
        """
        self.all_attributes: set[Attribute] = set()
        self.all_channels: set[Channel] = set()
        self.all_triggers: set[Trigger] = set()
        self.all_units: dict[str, Units] = {}
        self.all_sensors: set[Sensor] = set()

        def add_unique(bucket: set, item: Any) -> None:
            # NOTE(review): Attribute, Trigger and Sensor define no __eq__,
            # so for them this membership check is identity-based.
            assert item not in bucket
            bucket.add(item)

        for key, definition in spec.units.items():
            new_units = units_type(unit_id=key, definition=definition)
            assert new_units not in self.all_units.values()
            self.all_units[key] = new_units

        for key, definition in spec.attributes.items():
            add_unique(
                self.all_attributes,
                attribute_type(attr_id=key, definition=definition),
            )

        for key, definition in spec.channels.items():
            add_unique(
                self.all_channels,
                channel_type(
                    channel_id=key,
                    definition=definition,
                    units=self.all_units,
                ),
            )

        for key, definition in spec.triggers.items():
            add_unique(
                self.all_triggers,
                trigger_type(trigger_id=key, definition=definition),
            )

        for key, definition in spec.sensors.items():
            add_unique(
                self.all_sensors,
                sensor_type(item_id=key, definition=definition),
            )
721
722
def is_list_type(t) -> bool:
    """
    Checks whether `t` is a list type annotation.

    Matches ``list[X]``, ``typing.List[X]`` and the bare ``typing.List``
    alias, all of which report ``list`` as their origin. The builtin
    ``list`` class itself has no origin and is therefore NOT matched.

    Args:
        t: The type to check.
    Returns:
        True if `t` is a list type annotation, False otherwise.
    """
    # The former second clause `(origin is list and get_args(t) == ())` was
    # fully subsumed by `origin is list` and has been dropped.
    return typing.get_origin(t) is list
734
735
def is_primitive(value):
    """Return whether `value` is a scalar (number, string, or bool).

    Args:
        value: The value to check.

    Returns:
        True for instances (including subclasses) of int, float, complex,
        str, or bool; False otherwise (e.g. None, lists, dicts).
    """
    scalar_types = (bool, int, float, complex, str)
    return isinstance(value, scalar_types)
746
747
def create_dataclass_from_dict(cls, data, indent: int = 0):
    """
    Recursively convert loaded YAML data into instances of `cls`.

    Rules, applied in order:
      * list types (``list[T]`` / ``List[T]``): convert every item as ``T``.
      * scalar data (see ``is_primitive``): returned unchanged.
      * dict types (``dict[K, V]`` / ``Dict[K, V]``): convert every value as
        ``V``; an unparameterized dict is returned unchanged.
      * dataclasses: each field is read by name, falling back to the
        hyphenated spelling (``supported_buses`` -> ``supported-buses``), and
        converted according to the field's annotated type.
      * anything else (e.g. ``typing.Any`` or a plain ``str`` annotation):
        the value is returned unchanged. Previously a non-scalar value under
        ``Any`` (such as a nested mapping in ``extras``) crashed in
        ``dataclasses.fields``.

    Args:
      cls: The target type.
      data: The value produced by the YAML loader.
      indent: Recursion depth bookkeeping; only passed down.

    Returns:
      The converted value.

    Raises:
      AssertionError: If a dataclass field is missing or null in `data`.
    """
    if is_list_type(cls):
        item_type = typing.get_args(cls)[0]
        return [
            create_dataclass_from_dict(item_type, item, indent + 2)
            for item in data
        ]

    if is_primitive(data):
        return data

    # get_origin() maps dict[K, V] / Dict[K, V] to dict; plain classes have
    # no origin, so fall back to the class itself.
    origin = typing.get_origin(cls) or cls
    if isinstance(origin, type) and issubclass(origin, dict):
        type_args = typing.get_args(cls)
        if not type_args:
            # Untyped dict: no value type to convert into.
            return data
        value_type = type_args[1]
        return {
            key: create_dataclass_from_dict(value_type, val, indent + 2)
            for key, val in data.items()
        }

    if not is_dataclass(cls):
        # No schema to apply (e.g. typing.Any): keep the raw value.
        return data

    field_values = {}
    for field in fields(cls):
        field_value = data.get(field.name)
        if field_value is None:
            field_value = data.get(field.name.replace('_', '-'))

        assert (
            field_value is not None
        ), f"Missing required field '{field.name}' for {cls.__name__}"

        field_values[field.name] = create_dataclass_from_dict(
            field.type, field_value, indent + 2
        )

    return cls(**field_values)
798
799
def main() -> None:
    """
    Main entry point, this function will:
    - Get CLI flags
    - Read YAML from stdin
    - Find all attribute, channel, trigger, and unit definitions
    - Print header

    InputData stores attributes, channels, triggers, and sensors in sets
    whose iteration order depends on string hashing, which is randomized per
    interpreter run. They are sorted by ID here so the generated header is
    reproducible between runs. Units are kept in their input order.

    Raises:
      ValueError: If an unsupported language is selected.
    """
    args = get_args()
    yaml_input = yaml.safe_load(sys.stdin)
    spec: InputSpec = create_dataclass_from_dict(InputSpec, yaml_input)
    data = InputData(spec=spec)

    def sorted_by_id(items):
        return sorted(items, key=lambda item: item.id)

    if args.language == "cpp":
        out = CppHeader(
            using_zephyr=args.zephyr,
            package=args.package,
            attributes=sorted_by_id(data.all_attributes),
            channels=sorted_by_id(data.all_channels),
            triggers=sorted_by_id(data.all_triggers),
            units=list(data.all_units.values()),
            sensors=sorted_by_id(data.all_sensors),
        )
    else:
        raise ValueError(f"Invalid language selected: '{args.language}'")
    print(out)
826
827
def validate_package_arg(value: str) -> str:
    """
    Validate that the package argument is a dot-separated list of identifiers.

    Args:
      value: The package name, e.g. "com.google". An empty (or None) value
        means "no package" and is returned unchanged.

    Returns:
      The same value after being validated.

    Raises:
      argparse.ArgumentError: If `value` is not made up entirely of
        identifiers separated by single dots.
    """
    if value is None or value == "":
        return value
    # fullmatch, not match: re.match only anchors at the start, so inputs
    # such as "com.google!" or "com..google" used to be accepted and would
    # later produce an invalid C++ namespace.
    if not re.fullmatch(r"[a-zA-Z_$][\w$]*(\.[a-zA-Z_$][\w$]*)*", value):
        raise argparse.ArgumentError(
            argument=None,
            message=f"Invalid string {value}. Must use alphanumeric values "
            "separated by dots.",
        )
    return value
847
848
def get_args() -> Args:
    """
    Parse the command line.

    Flags:
      --package / -pkg: dot-separated package name (validated by
        validate_package_arg); defaults to "".
      --language: output language, only "cpp" is accepted (the default).
      --zephyr: also emit the Zephyr mapping section.

    Returns:
      Typed arguments class instance, with the package split on '.'.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--package",
        "-pkg",
        default="",
        type=validate_package_arg,
        help="Output package name separated by '.', example: com.google",
    )
    parser.add_argument("--language", type=str, choices=["cpp"], default="cpp")
    parser.add_argument("--zephyr", action="store_true")

    parsed = parser.parse_args()
    package_parts = parsed.package.split(".")
    return Args(
        package=package_parts,
        language=parsed.language,
        zephyr=parsed.zephyr,
    )
880
881
if __name__ == "__main__":
    # Only run the generator when executed as a script, not on import.
    main()
884