diff --git a/roles/write-inventory/library/test_write_inventory.py b/roles/write-inventory/library/test_write_inventory.py
index 7951e84d3..a1bef81d5 100644
--- a/roles/write-inventory/library/test_write_inventory.py
+++ b/roles/write-inventory/library/test_write_inventory.py
@@ -56,6 +56,14 @@ xenial:
     region: GRA1
 """)
 
+GROUPS_INPUT = yaml.safe_load("""
+all: []
+ungrouped: []
+puppet:
+  - bionic
+  - xenial
+""")
+
 
 class TestWriteInventory(testtools.TestCase):
     def assertOutput(self, dest, ref):
@@ -67,10 +75,18 @@ class TestWriteInventory(testtools.TestCase):
         '''Test passing all variables'''
         dest = self.useFixture(fixtures.TempDir()).path
         dest = os.path.join(dest, 'out.yaml')
-        run(dest, INPUT, None, None)
+        run(dest, INPUT, GROUPS_INPUT, None, None)
 
         self.assertOutput(dest, {
             'all': {
+                'children': {
+                    'puppet': {
+                        'hosts': {
+                            'bionic': None,
+                            'xenial': None,
+                        },
+                    },
+                },
                 'hosts': {
                     'bionic': {
                         "ansible_connection": "ssh",
@@ -92,10 +108,18 @@ class TestWriteInventory(testtools.TestCase):
         '''Test incuding vars'''
         dest = self.useFixture(fixtures.TempDir()).path
         dest = os.path.join(dest, 'out.yaml')
-        run(dest, INPUT, ['ansible_host'], None)
+        run(dest, INPUT, GROUPS_INPUT, ['ansible_host'], None)
 
         self.assertOutput(dest, {
             'all': {
+                'children': {
+                    'puppet': {
+                        'hosts': {
+                            'bionic': None,
+                            'xenial': None,
+                        },
+                    },
+                },
                 'hosts': {
                     'bionic': {
                         "ansible_host": "104.130.217.77",
@@ -111,10 +135,18 @@ class TestWriteInventory(testtools.TestCase):
         '''Test passing all variables'''
         dest = self.useFixture(fixtures.TempDir()).path
         dest = os.path.join(dest, 'out.yaml')
-        run(dest, INPUT, None, ['ansible_user'])
+        run(dest, INPUT, GROUPS_INPUT, None, ['ansible_user'])
 
         self.assertOutput(dest, {
             'all': {
+                'children': {
+                    'puppet': {
+                        'hosts': {
+                            'bionic': None,
+                            'xenial': None,
+                        },
+                    },
+                },
                 'hosts': {
                     'bionic': {
                         "ansible_connection": "ssh",
diff --git a/roles/write-inventory/library/write_inventory.py b/roles/write-inventory/library/write_inventory.py
index 0707a6496..1fdb6a32f 100755
--- a/roles/write-inventory/library/write_inventory.py
+++ b/roles/write-inventory/library/write_inventory.py
@@ -28,9 +28,23 @@ VARS = [
 ]
 
 
-def run(dest, hostvars, include, exclude):
+def run(dest, hostvars, groups, include, exclude):
+    children = {}
+    for group, hostnames in groups.items():
+        if group == 'all' or group == 'ungrouped':
+            continue
+        children[group] = {}
+        children[group]['hosts'] = {}
+        for host in hostnames:
+            children[group]['hosts'][host] = None
+
     out_all = {}
-    out = {'all': {'hosts': out_all}}
+    out = {
+        'all': {
+            'children': children,
+            'hosts': out_all
+        }
+    }
     for host, hvars in hostvars.items():
         d = {}
         for v in VARS:
@@ -54,6 +68,7 @@ def ansible_main():
         argument_spec=dict(
             dest=dict(required=True, type='path'),
             hostvars=dict(required=True, type='raw'),
+            groups=dict(required=True, type='raw'),
             include_hostvars=dict(type='list'),
             exclude_hostvars=dict(type='list'),
         )
@@ -63,10 +78,11 @@ def ansible_main():
 
     dest = p.get('dest')
     hostvars = p.get('hostvars')
+    groups = p.get('groups')
     include = p.get('include_hostvars')
     exclude = p.get('exclude_hostvars')
 
-    run(dest, hostvars, include, exclude)
+    run(dest, hostvars, groups, include, exclude)
 
     module.exit_json(changed=True)
 
diff --git a/roles/write-inventory/tasks/main.yaml b/roles/write-inventory/tasks/main.yaml
index 52ae232dd..38b23834e 100644
--- a/roles/write-inventory/tasks/main.yaml
+++ b/roles/write-inventory/tasks/main.yaml
@@ -2,5 +2,6 @@
   write_inventory:
     dest: "{{ write_inventory_dest }}"
     hostvars: "{{ hostvars }}"
+    groups: "{{ groups }}"
     include_hostvars: "{{ write_inventory_include_hostvars | default(omit) }}"
     exclude_hostvars: "{{ write_inventory_exclude_hostvars | default(omit) }}"